In [1]:
from torch import nn
from torch import flatten
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torch.optim as opt
from torch.optim.lr_scheduler import StepLR
import os

In [2]:
class patch_embedding(nn.Module) :
  def __init__(self, img_size=224, inchannels=3) :
    super(patch_embedding,self).__init__()
    self.patch_conv2d = nn.Conv2d(in_channels=inchannels, out_channels=768, stride=16, kernel_size=16)
    self.class_token = nn.Parameter(torch.randn(1,768))
    self.position_token = nn.Parameter(torch.zeros(197,768))
    self.dropout = nn.Dropout(0.1)
  def forward(self, x) :#x(1,3,224,224)
    x = self.patch_conv2d(x) #x(1,768,14,14)
    x = flatten(x, start_dim=2, end_dim=3) #x(1,768,196)
    x = x.transpose(1,2) #exchange dim, x(1,196,768)
    class_tokens = self.class_token.expand(x.shape[0],-1,-1)
    x = torch.cat((class_tokens,x),dim=1)# x(1,197,768)
    x = x + self.position_token
    x = self.dropout(x)
    return x

In [3]:
class MLP_block(nn.Module) :
  def __init__(self) :
    super(MLP_block,self).__init__()
    self.MLP_linear_1 = nn.Linear(in_features=768,out_features=3072)
    self.relu = nn.ReLU()
    self.dropout_1 = nn.Dropout(0.1)
    self.MLP_linear_2 = nn.Linear(in_features=3072,out_features=768)
    self.dropout_2 = nn.Dropout(0.1)

  def forward(self,x) :
    x = self.MLP_linear_1(x)
    x = self.relu(x)
    x = self.dropout_1(x)
    x = self.MLP_linear_2(x)
    x = self.dropout_2(x)
    return x

In [4]:
class multi_head_att(nn.Module) :
  def __init__(self) :
    super(multi_head_att,self).__init__()
    self.q = nn.Linear(768,768)
    self.k = nn.Linear(768,768)
    self.v = nn.Linear(768,768)
    self.out = nn.Linear(768,768)
    self.softmax = nn.Softmax(dim=-1)

  def transpose(self,x) :#x(1,197,768)
    x = x.reshape(x.size()[0:2]+(12,64)) #x(1,197,12,64)
    x = x.permute(0,2,1,3) #x(1,12,197,64)
    return x

  def forward(self,x) :
    all_q = self.q(x)
    all_v = self.v(x)
    all_k = self.k(x)
    all_q_trans = self.transpose(all_q)
    all_v_trans = self.transpose(all_v)
    all_k_trans = self.transpose(all_k)
    att_sc = torch.matmul(all_q_trans,all_k_trans.transpose(-1,-2))
    att_sc = att_sc/8 #torch.sqrt(64)
    att_sc = self.softmax(att_sc)
    context_layer = torch.matmul(att_sc,all_v_trans) #(1,12,197,64)
    context_layer = context_layer.permute(0,2,1,3) #(1,197,12,64)
    new_shape = context_layer.size()[0:2]+(768,) #(1,197,768)
    context_layer = context_layer.reshape(*new_shape)
    out = self.out(context_layer)
    return out

In [5]:
class encoder_block(nn.Module) :
  def __init__(self) :
    super(encoder_block,self).__init__()
    self.layernorm_1 = nn.LayerNorm(normalized_shape=(197,768))
    self.layernorm_2 = nn.LayerNorm(normalized_shape=(197,768))
    self.mul_att = multi_head_att()
    self.dropout_1 = nn.Dropout(0.1)
    self.dropout_2 = nn.Dropout(0.1)
    self.MLP = MLP_block()

  def forward(self,x) :
    y = self.layernorm_1(x)
    y = self.mul_att(y)
    y = self.dropout_1(y)
    x = y+x
    y = self.layernorm_2(x)
    y = self.MLP(y)
    y = self.dropout_2(y)
    x = y+x
    return x

In [6]:
class vit_transformer(nn.Module) :
  def __init__(self,nums_classes:int=2) :
    super(vit_transformer,self).__init__()
    self.encoder_blocks = self.__make_layer()
    self.start_blocks = patch_embedding()
    self.layernorm = nn.LayerNorm(normalized_shape=(197,768))
    self.linear_out = nn.Linear(in_features=100,out_features=nums_classes)
    self.pre_logits_linear = nn.Linear(in_features=768,out_features=100)
    self.pre_logits_tanh = nn.Tanh()

  def __make_layer(self,nums_encoder_block=12) :
    layers = []
    for _ in range(nums_encoder_block) :
      layers.append(encoder_block())
    return nn.Sequential(*layers)

  def forward(self,x) :
    x = self.start_blocks(x)
    x = self.encoder_blocks(x)
    x = self.layernorm(x)
    x = x[:,0,:]
    x = self.pre_logits_linear(x)
    x = self.pre_logits_tanh(x)
    x = self.linear_out(x)
    return x

In [7]:
def transform_jpg(path:str) :
  img = Image.open(path).convert("RGB")
  transform = transforms.Compose([
      transforms.RandomHorizontalFlip(p=0.5),
      transforms.Compose([transforms.RandomCrop(224, padding=4)]),
      transforms.ToTensor()
  ])
  img = transform(img).float()
  return img

In [8]:
train_dir = os.path.join("data","train")
test_dir = os.path.join("data","test")
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225] 
train_augs = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Resize(224),
    transforms.Compose([transforms.RandomCrop(224, padding=4)]),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
test_augs = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
train_set = datasets.ImageFolder(train_dir, transform=train_augs)
test_set = datasets.ImageFolder(test_dir, transform=test_augs)
batch_size = 32
train_iter = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_iter = DataLoader(test_set, batch_size=batch_size, shuffle=True)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
lr = 0.00005
model = vit_transformer()
model.to(device)
optimizer = opt.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=0.9)  # 每1个epoch学习率乘以0.9

cuda


In [None]:
criterion = nn.CrossEntropyLoss().to(device)  # 定义损失函数
criterion = criterion.to(device)
num_epochs = 40  # 设置训练轮数
for epoch in range(num_epochs):
  model.train()  # 设置模型为训练模式
  running_loss = 0.0
  for i,(inputs,labels) in enumerate(train_iter):
      if i%100 == 1 :
          print(i)  
      labels2d = torch.tensor([([0,1] if labels[k]==1 else [1,0]) for k in range(len(labels))]).float()
      inputs, labels2d = inputs.to(device), labels2d.to(device)  # 将输入和标签移至设备
      optimizer.zero_grad()  # 清零梯度
      outputs = model(inputs)  # 前向传播
      loss = criterion(outputs, labels2d).to(device)  # 计算损失
      loss.backward()  # 反向传播:
      optimizer.step()  # 更新模型参数
      running_loss += loss.item()  # 累加损失
  # for name,s in model.named_parameters() :
  #     if "class" in name :
  #         print(s)
  scheduler.step()
  model.eval()
  acc = 0
  for j,(inputs,labels) in enumerate(test_iter):
    inputs, labels = inputs.to(device), labels.to(device)  # 将输入和标签移至设备
    outputs = model(inputs)
    #print(outputs)
    acc += (torch.argmax(outputs,dim=1)==labels).sum()
  print(int(acc)/len(test_set),running_loss)
  print(epoch)

1
101
201
301
401
501
601


In [None]:
temp=torch.randn(2,3,224,224)

In [None]:
outputs = model(temp.to(device))

In [15]:
temp

tensor([[[[-1.2070e+00, -1.4405e+00,  1.0942e+00,  ..., -1.7945e+00,
           -1.1536e+00, -5.9671e-01],
          [ 4.0849e-01,  9.6268e-01,  4.1731e-01,  ...,  9.8890e-01,
            2.1587e-01, -6.0898e-01],
          [-8.4422e-01, -1.4986e+00, -1.5211e+00,  ...,  7.7938e-01,
           -2.9462e-01, -1.0463e-01],
          ...,
          [-8.0904e-01,  1.7094e-01, -4.3821e-01,  ..., -8.0661e-01,
            4.2611e-03, -8.7575e-01],
          [ 4.5484e-01, -1.5969e+00, -4.4237e-01,  ...,  2.2141e+00,
           -2.5057e-01,  1.6764e+00],
          [-5.8072e-01, -1.6755e+00, -6.0562e-01,  ..., -1.4687e+00,
           -1.7963e+00,  6.8623e-01]],

         [[-1.3426e+00,  8.0804e-02,  1.1976e+00,  ...,  2.6589e-01,
            6.2099e-01, -5.5215e-01],
          [-1.6689e+00,  5.5142e-01,  2.0722e-01,  ...,  1.1988e+00,
            2.1333e+00, -1.2328e+00],
          [-1.0650e+00, -1.3250e+00,  1.3441e+00,  ..., -7.4189e-01,
           -1.0795e+00, -1.3313e-01],
          ...,
     

In [16]:
outputs

tensor([[  9.7530, -10.6582],
        [  9.7530, -10.6582]], device='cuda:0', grad_fn=<AddmmBackward0>)