In [None]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard event writer shared by the cells below (train/test loss scalars).
# NOTE(review): despite the ".log" suffix, this value is passed as `log_dir`,
# i.e. it names a directory of event files — confirm the naming is intended.
writer = SummaryWriter(
    log_dir="/root/tf-logs/resnet_logs.log",
)

In [None]:
from torch import nn
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pickle
import os
import numpy as np
import torch.nn.functional as F
# 4 convolutions
class Resblock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a skip connection.

    Args:
        input_channels: Number of channels of the incoming feature map.
        use_1x1conv: When true, the block doubles the channel count and the
            skip path becomes a strided 1x1 convolution followed by batch
            norm. When false, the channel count is kept and the input itself
            is added back unchanged.
        conv1_stride: Stride of the first 3x3 convolution (and of the 1x1
            skip convolution when it exists). With ``use_1x1conv`` false the
            stride has to keep the spatial size, otherwise the addition in
            ``forward`` cannot match shapes.
    """

    def __init__(self, input_channels, use_1x1conv, conv1_stride):
        super().__init__()
        width = 2 * input_channels if use_1x1conv else input_channels

        # Main path: conv -> BN -> ReLU -> conv -> BN. The final ReLU is applied
        # in forward(), after the skip connection has been added.
        main_path = [
            nn.Conv2d(input_channels, width, kernel_size=(3, 3),
                      stride=conv1_stride, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=(3, 3), padding=1, bias=False),
            nn.BatchNorm2d(width),
        ]
        self.conv = nn.Sequential(*main_path)

        self.use1x1conv = use_1x1conv

        # Projection shortcut, only created when the block changes shape.
        if use_1x1conv:
            self.res_path = nn.Sequential(
                nn.Conv2d(input_channels, 2 * input_channels, kernel_size=(1, 1),
                          stride=conv1_stride, bias=False),
                nn.BatchNorm2d(width),
            )

        self.act = nn.ReLU()

    def forward(self, input_image):
        """Return ``ReLU(main_path(x) + shortcut(x))``."""
        main_out = self.conv(input_image)
        shortcut = self.res_path(input_image) if self.use1x1conv else input_image
        return self.act(main_out + shortcut)
            
class resnet18(nn.Module):
    """ResNet-18-style image classifier assembled from ``Resblock`` stages.

    Layout: a 3->64 stem convolution, a 64-channel stage (one extra conv plus
    two residual blocks), then three stages that each double the channel count
    with a stride-2 projection block (64->128->256->512), a 4x4 average pool
    and a linear classifier. The stem has no max-pooling stage, so the spatial
    resolution is only reduced by the three stride-2 blocks.

    The fixed 4x4 average pool followed by a 512-feature linear layer requires
    the pooled map to be 1x1; with three stride-2 stages this holds for 32x32
    inputs (32 -> 16 -> 8 -> 4).

    Args:
        num_class: Number of output classes, i.e. the width of the final
            linear layer.
    """

    def __init__(self, num_class):
        super(resnet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=64,
                out_channels=64,
                kernel_size=(3, 3),
                padding=1,
                stride=1,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            Resblock(input_channels=64, use_1x1conv=False, conv1_stride=1),
            Resblock(input_channels=64, use_1x1conv=False, conv1_stride=1),
        )

        # Each of the following stages starts with a stride-2 projection block
        # that doubles the channels, followed by a shape-preserving block.
        self.conv3 = nn.Sequential(
            Resblock(input_channels=64, use_1x1conv=True, conv1_stride=2),
            Resblock(input_channels=128, use_1x1conv=False, conv1_stride=1),
        )
        self.conv4 = nn.Sequential(
            Resblock(input_channels=128, use_1x1conv=True, conv1_stride=2),
            Resblock(input_channels=256, use_1x1conv=False, conv1_stride=1),
        )
        self.conv5 = nn.Sequential(
            Resblock(input_channels=256, use_1x1conv=True, conv1_stride=2),
            Resblock(input_channels=512, use_1x1conv=False, conv1_stride=1),
        )

        # The head honours `num_class`; it used to be hard-coded to 10 outputs,
        # silently ignoring the constructor argument.
        self.fc = nn.Linear(512, num_class)

    def forward(self, x):
        """Return class logits of shape ``(x.size(0), num_class)``."""
        conv_output = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
        conv_output = F.avg_pool2d(conv_output, 4)
        # Flatten everything except the batch dimension for the linear head.
        conv_output = conv_output.view(x.size(0), -1)
        return self.fc(conv_output)
        


In [None]:
class Convblock(nn.Module):
    """Plain (non-residual) counterpart of a residual block.

    Runs conv -> BN -> ReLU -> conv -> BN -> ReLU with no skip connection.

    Args:
        input_channels: Number of channels of the incoming feature map.
        use_1x1conv: When true, the block doubles the channel count; otherwise
            the channel count is kept.
        conv1_stride: Stride of the first 3x3 convolution.

    NOTE(review): when ``use_1x1conv`` is true a 1x1 convolution is registered
    as ``res_path``, but ``forward`` never applies it, so its weights receive
    no gradient. It is kept so the module's parameters/state_dict stay as they
    were — confirm whether it can be dropped.
    """

    def __init__(self, input_channels, use_1x1conv, conv1_stride):
        super().__init__()
        width = 2 * input_channels if use_1x1conv else input_channels

        stacked_layers = [
            nn.Conv2d(input_channels, width, kernel_size=(3, 3),
                      stride=conv1_stride, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=(3, 3), padding=1, bias=False),
            nn.BatchNorm2d(width),
        ]
        self.conv = nn.Sequential(*stacked_layers)

        self.use1x1conv = use_1x1conv

        # Registered but not used by forward() (see class docstring).
        if use_1x1conv:
            self.res_path = nn.Conv2d(input_channels, 2 * input_channels,
                                      kernel_size=(1, 1), stride=conv1_stride,
                                      bias=False)

        self.act = nn.ReLU()

    def forward(self, input_image):
        """Return ``ReLU(conv_stack(x))``; the input is not added back."""
        features = self.conv(input_image)
        return self.act(features)
            
class fakeresnet18(nn.Module):
    """Same layer layout as a ResNet-18-style net, built from ``Convblock``.

    Layout: a 3->64 stem convolution, a 64-channel stage (one extra conv plus
    two blocks), then three stages that each double the channel count with a
    stride-2 block (64->128->256->512), a 4x4 average pool and a linear
    classifier. The stem has no max-pooling stage.

    The fixed 4x4 average pool followed by a 512-feature linear layer requires
    the pooled map to be 1x1; with three stride-2 stages this holds for 32x32
    inputs (32 -> 16 -> 8 -> 4).

    Args:
        num_class: Number of output classes, i.e. the width of the final
            linear layer.
    """

    def __init__(self, num_class):
        super(fakeresnet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=64,
                out_channels=64,
                kernel_size=(3, 3),
                padding=1,
                stride=1,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            Convblock(input_channels=64, use_1x1conv=False, conv1_stride=1),
            Convblock(input_channels=64, use_1x1conv=False, conv1_stride=1),
        )

        # Each of the following stages starts with a stride-2 block that
        # doubles the channels, followed by a shape-preserving block.
        self.conv3 = nn.Sequential(
            Convblock(input_channels=64, use_1x1conv=True, conv1_stride=2),
            Convblock(input_channels=128, use_1x1conv=False, conv1_stride=1),
        )
        self.conv4 = nn.Sequential(
            Convblock(input_channels=128, use_1x1conv=True, conv1_stride=2),
            Convblock(input_channels=256, use_1x1conv=False, conv1_stride=1),
        )
        self.conv5 = nn.Sequential(
            Convblock(input_channels=256, use_1x1conv=True, conv1_stride=2),
            Convblock(input_channels=512, use_1x1conv=False, conv1_stride=1),
        )

        # The head honours `num_class`; it used to be hard-coded to 10 outputs,
        # silently ignoring the constructor argument.
        self.fc = nn.Linear(512, num_class)

    def forward(self, x):
        """Return class logits of shape ``(x.size(0), num_class)``."""
        conv_output = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(x)))))
        conv_output = F.avg_pool2d(conv_output, 4)
        # Flatten everything except the batch dimension for the linear head.
        conv_output = conv_output.view(x.size(0), -1)
        return self.fc(conv_output)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pickle
import os
import numpy as np

class CIFAR10Dataset(Dataset):
    """In-memory dataset built from the pickled batch files in a directory.

    Every regular file directly under ``root`` is unpickled and must be a
    mapping with a ``b'data'`` array and a ``b'labels'`` sequence. All batches
    are concatenated and reshaped to ``(N, 3, 32, 32)``.

    Args:
        root: Directory containing the batch files.
        train: Accepted for call-site compatibility but not consulted; which
            split gets loaded is determined solely by ``root``.

    Raises:
        ValueError: If ``root`` contains no batch files.
    """

    def __init__(self, root, train = True):
        data_batch_list = []
        target_list = []
        # sorted(): os.listdir() returns entries in arbitrary order, which would
        # make the index -> sample mapping differ between machines and runs.
        for file in sorted(os.listdir(root)):
            filename = os.path.join(root, file)
            # Skip sub-directories (e.g. ".ipynb_checkpoints"); open() on them fails.
            if not os.path.isfile(filename):
                continue
            with open(filename, 'rb') as fo:
                # NOTE(security): pickle.load can execute arbitrary code; only
                # point `root` at trusted files.
                data_batch = pickle.load(fo, encoding='bytes')
            data_batch_list.append(data_batch[b'data'])
            target_list.extend(data_batch[b'labels'])

        if not data_batch_list:
            raise ValueError(f"no batch files found in {root!r}")

        self.data = np.vstack(data_batch_list).reshape(-1, 3, 32, 32)
        self.num_samples = self.data.shape[0]
        self.targets = target_list

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is float32, values not rescaled."""
        # torch.tensor() always copies, so the returned sample never aliases
        # the dataset's backing array.
        image = torch.tensor(self.data[idx], dtype=torch.float32)
        return image, self.targets[idx]

    def __len__(self):
        """Return the number of samples loaded from ``root``."""
        return self.num_samples
        
        
# Load both splits into memory; each directory holds that split's batch files.
trainset = CIFAR10Dataset(root='./cifar_10/train', train=True)
testset = CIFAR10Dataset(root='./cifar_10/test', train=False)

# Only the training loader shuffles; the test loader keeps dataset order.
trainloader = DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
testloader = DataLoader(
    testset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

# Initialise the model, the loss function and the optimiser.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device:", device)

# nn.Module.to() moves the parameters in place and returns the module itself.
model = resnet18(num_class=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-2,
    weight_decay=5e-4,
)

# Train the model.
# `current_iter` counts optimiser steps across all epochs and is used as the
# x-axis ("global step") of the scalars written to TensorBoard.
current_iter = 0

num_epochs = 100
for epoch in range(num_epochs):
    running_loss = 0.0  # training loss summed since the last console report
    for i, data in enumerate(trainloader):
        # Re-enter training mode on every step: the periodic evaluation further
        # down leaves the model in eval mode.
        model.train()
        inputs = data[0].to(device)
        labels = data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        running_loss += step_loss
        writer.add_scalar("train_loss", step_loss, current_iter)
        current_iter += 1

        # Everything below runs only on every 100th mini-batch.
        if (i + 1) % 100 != 0:
            continue

        # Report the mean training loss of the last 100 mini-batches.
        print('[%d, %5d] loss: %.3f' %
              (epoch + 1, i + 1, running_loss / 100))
        running_loss = 0.0

        # Evaluate on the whole test loader and log the mean per-batch loss.
        model.eval()
        test_loss = 0
        testloader_len = 0
        with torch.no_grad():
            for test_data in testloader:
                inputs = test_data[0].to(device)
                labels = test_data[1].to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                testloader_len += 1
        print(test_loss / testloader_len)
        writer.add_scalar("test_loss", test_loss / testloader_len, current_iter)
            

# Per-class accuracy on the test set.
classes = list(range(10))
# Maps a label index to the name used in the report (here the index itself).
class_names = {i: i for i in range(10)}
correct = {classname: 0 for classname in classes}
total = {classname: 0 for classname in classes}

model.eval()  # switch the model to evaluation mode
with torch.no_grad():
    for data in testloader:
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        # Element-wise hit mask, shape (batch,). It is deliberately not
        # squeezed: squeeze() would collapse a batch of size 1 to a 0-dim
        # tensor and make per-sample indexing fail.
        c = predicted == labels

        for label, hit in zip(labels.tolist(), c.tolist()):
            correct[class_names[label]] += int(hit)
            total[class_names[label]] += 1

for classname, correct_count in correct.items():
    # A class without any test sample has no defined accuracy.
    if total[classname] == 0:
        print(f'Accuracy for class {classname}: n/a (no samples)')
        continue
    accuracy = 100 * correct_count / total[classname]
    print(f'Accuracy for class {classname}: {accuracy:.2f}%')


print('Finished Training')

# Test the model: overall top-1 accuracy over the whole test loader.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images = data[0].to(device)
        labels = data[1].to(device)
        outputs = model(images)
        # Tensor.max(dim) returns (values, indices); the indices are the
        # predicted class ids.
        predicted = outputs.data.max(1)[1]
        total += labels.shape[0]
        correct += int((predicted == labels).sum())

# %d truncates the percentage to an integer.
print('Accuracy on the test images: %d %%' % (100 * correct / total))