In [1]:
import torchvision.models as models
import os
import wandb
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from torchvision.utils import save_image
import matplotlib.pyplot as plt


In [2]:
resnet18 = models.resnet18(pretrained=False, num_classes=10)
class ResNet18(nn.Module):
    def __init__(self, num_classes):
        super(ResNet18, self).__init__()
    
        self.adapter= nn.Sequential( nn.Linear(512, 128), nn.ReLU() )
   
        self.conv1 = resnet18.conv1
        self.bn1 = resnet18.bn1
        self.relu = resnet18.relu
        self.maxpool = resnet18.maxpool
        self.layer1 = resnet18.layer1
        self.layer2 = resnet18.layer2
        self.layer3 = resnet18.layer3
        self.layer4 = resnet18.layer4
        self.avgpool = resnet18.avgpool
        self.fc = nn.Sequential( nn.Linear(128, num_classes))
        # print(self)

    def forward(self, x, no_fc=False):
        # See note [TorchScript super()]
        # print(self)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.adapter(x)
        if no_fc:
            return x
        x = self.fc(x)

        return x



In [3]:
if not os.path.exists('./mlp_img'):
    os.mkdir('./mlp_img')

def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 32, 32)
    return x


num_epochs = 10
batch_size = 512
learning_rate = 1e-3

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# dataset_train = MNIST('./data', transform=img_transform, download=True,train = True)
# dataset_test = MNIST('./data', transform=img_transform, download=True,train = False)

train_transform = transforms.Compose([
    transforms.RandomCrop(32,padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

dataset_train = CIFAR10('./data', transform=train_transform, download=True,train = True)
dataset_test = CIFAR10('./data', transform=test_transform, download=True,train = False)


# dataloader
train_loader = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

test_loader = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)
# dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified
Files already downloaded and verified


In [4]:
from time import sleep

# in this cell, I want to train the model and save the weights

# device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device=torch.device('cuda'if torch.cuda.is_available() else 'cpu')
print(device)
modelResnet=ResNet18(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(modelResnet.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    print("Epoch: {}".format(epoch))
    for i, (images, labels) in enumerate(train_loader):
        print(i)
        images = images.to(device)
        labels = labels.to(device)

        outputs = modelResnet(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("Epoch: {}/{}".format(epoch, num_epochs - 1))

cuda
Epoch: 0
0
Epoch: 0/9
1
Epoch: 0/9
2
Epoch: 0/9
3
Epoch: 0/9
4
Epoch: 0/9
5
Epoch: 0/9
6
Epoch: 0/9
7
Epoch: 0/9
8
Epoch: 0/9
9
Epoch: 0/9
10
Epoch: 0/9
11
Epoch: 0/9
12
Epoch: 0/9
13
Epoch: 0/9
14
Epoch: 0/9
15
Epoch: 0/9
16
Epoch: 0/9
17
Epoch: 0/9
18
Epoch: 0/9
19
Epoch: 0/9
20
Epoch: 0/9
21
Epoch: 0/9
22
Epoch: 0/9
23
Epoch: 0/9
24
Epoch: 0/9
25
Epoch: 0/9
26
Epoch: 0/9
27
Epoch: 0/9
28
Epoch: 0/9
29
Epoch: 0/9
30
Epoch: 0/9
31
Epoch: 0/9
32
Epoch: 0/9
33
Epoch: 0/9
34
Epoch: 0/9
35
Epoch: 0/9
36
Epoch: 0/9
37
Epoch: 0/9
38
Epoch: 0/9
39
Epoch: 0/9
40
Epoch: 0/9
41
Epoch: 0/9
42
Epoch: 0/9
43
Epoch: 0/9
44
Epoch: 0/9
45
Epoch: 0/9
46
Epoch: 0/9
47
Epoch: 0/9
48
Epoch: 0/9
49
Epoch: 0/9
50
Epoch: 0/9
51
Epoch: 0/9
52
Epoch: 0/9
53
Epoch: 0/9
54
Epoch: 0/9
55
Epoch: 0/9
56
Epoch: 0/9
57
Epoch: 0/9
58
Epoch: 0/9
59
Epoch: 0/9
60
Epoch: 0/9
61
Epoch: 0/9
62
Epoch: 0/9
63
Epoch: 0/9
64
Epoch: 0/9
65
Epoch: 0/9
66
Epoch: 0/9
67
Epoch: 0/9
68
Epoch: 0/9
69
Epoch: 0/9
70
Epoch: 0/9
71

In [5]:
import torch
from torchvision import models
from torchsummary import summary
import torchvision.models as models
#show what the model is made of we can compare it to the original resnet 18
# resnet18 = models.resnet18()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(modelResnet, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [6]:

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = modelResnet(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
    # save the model
    torch.save(modelResnet.state_dict(), './resnet18.pth')
    # save the image
    img = images[0].cpu()
    img = img.view(1, 3, 32, 32)
    save_image(img, './mlp_img/image_{}.png'.format(epoch))

Accuracy of the model on the 10000 test images: 76.58 %


In [7]:

from linear_classifier import LinearClassifier
#joined model
class JoinedModel(nn.Module):
    def __init__(self,num_classes=10):
        
        super(JoinedModel, self).__init__()
        #uses the resnets weights already trained
        self.resnet = modelResnet
        #classifier parts
        self.classifier = LinearClassifier()
    def forward(self, x):
        x = self.resnet.forward(x,no_fc=True)
        #classifier part
        x = self.classifier(x)
        return x
joined_model = JoinedModel().to(device)

In [11]:
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(joined_model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [9]:
# print(joined_model)

JoinedModel(
  (resnet): ResNet18(
    (adapter): Sequential(
      (0): Linear(in_features=512, out_features=128, bias=True)
      (1): ReLU()
    )
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), s

In [12]:
#freeze all weights 
for param in joined_model.parameters():
    param.requires_grad = False
#unfreeze the classifier
joined_model.classifier.fc1.weight.requires_grad = True
joined_model.classifier.fc1.bias.requires_grad = True
joined_model.classifier.fc2.weight.requires_grad = True
joined_model.classifier.fc2.bias.requires_grad = True


In [13]:
#verify weight are frozen

for name, param in joined_model.named_parameters():
    print(name, param.requires_grad)


resnet.adapter.0.weight False
resnet.adapter.0.bias False
resnet.conv1.weight False
resnet.bn1.weight False
resnet.bn1.bias False
resnet.layer1.0.conv1.weight False
resnet.layer1.0.bn1.weight False
resnet.layer1.0.bn1.bias False
resnet.layer1.0.conv2.weight False
resnet.layer1.0.bn2.weight False
resnet.layer1.0.bn2.bias False
resnet.layer1.1.conv1.weight False
resnet.layer1.1.bn1.weight False
resnet.layer1.1.bn1.bias False
resnet.layer1.1.conv2.weight False
resnet.layer1.1.bn2.weight False
resnet.layer1.1.bn2.bias False
resnet.layer2.0.conv1.weight False
resnet.layer2.0.bn1.weight False
resnet.layer2.0.bn1.bias False
resnet.layer2.0.conv2.weight False
resnet.layer2.0.bn2.weight False
resnet.layer2.0.bn2.bias False
resnet.layer2.0.downsample.0.weight False
resnet.layer2.0.downsample.1.weight False
resnet.layer2.0.downsample.1.bias False
resnet.layer2.1.conv1.weight False
resnet.layer2.1.bn1.weight False
resnet.layer2.1.bn1.bias False
resnet.layer2.1.conv2.weight False
resnet.layer2.1.bn

In [14]:
#train the joined model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, joined_model.parameters()), lr=learning_rate)


In [16]:


#train the classification layer
for epoch in range(num_epochs):
    print("Epoch: {}".format(epoch))
    for i, (images, labels) in enumerate(train_loader):
        print(i)
        images = images.to(device)
        labels = labels.to(device)

        outputs = joined_model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("Epoch: {}/{}".format(epoch, num_epochs - 1))

Epoch: 0
0
Epoch: 0/9
1
Epoch: 0/9
2
Epoch: 0/9
3
Epoch: 0/9
4
Epoch: 0/9
5
Epoch: 0/9
6
Epoch: 0/9
7
Epoch: 0/9
8
Epoch: 0/9
9
Epoch: 0/9
10
Epoch: 0/9
11
Epoch: 0/9
12
Epoch: 0/9
13
Epoch: 0/9
14
Epoch: 0/9
15
Epoch: 0/9
16
Epoch: 0/9
17
Epoch: 0/9
18
Epoch: 0/9
19
Epoch: 0/9
20
Epoch: 0/9
21
Epoch: 0/9
22
Epoch: 0/9
23
Epoch: 0/9
24
Epoch: 0/9
25
Epoch: 0/9
26
Epoch: 0/9
27
Epoch: 0/9
28
Epoch: 0/9
29
Epoch: 0/9
30
Epoch: 0/9
31
Epoch: 0/9
32
Epoch: 0/9
33
Epoch: 0/9
34
Epoch: 0/9
35
Epoch: 0/9
36
Epoch: 0/9
37
Epoch: 0/9
38
Epoch: 0/9
39
Epoch: 0/9
40
Epoch: 0/9
41
Epoch: 0/9
42
Epoch: 0/9
43
Epoch: 0/9
44
Epoch: 0/9
45
Epoch: 0/9
46
Epoch: 0/9
47
Epoch: 0/9
48
Epoch: 0/9
49
Epoch: 0/9
50
Epoch: 0/9
51
Epoch: 0/9
52
Epoch: 0/9
53
Epoch: 0/9
54
Epoch: 0/9
55
Epoch: 0/9
56
Epoch: 0/9
57
Epoch: 0/9
58
Epoch: 0/9
59
Epoch: 0/9
60
Epoch: 0/9
61
Epoch: 0/9
62
Epoch: 0/9
63
Epoch: 0/9
64
Epoch: 0/9
65
Epoch: 0/9
66
Epoch: 0/9
67
Epoch: 0/9
68
Epoch: 0/9
69
Epoch: 0/9
70
Epoch: 0/9
71
Epoc

In [17]:
# evaluate the joined model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = joined_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
    # save the model
    torch.save(joined_model.state_dict(), './joined_model.pth')
    # save the image
    img = images[0].cpu()
    img = img.view(1, 3, 32, 32)
    save_image(img, './mlp_img_resnet/image_{}.png'.format(epoch))


Accuracy of the model on the 10000 test images: 76.93 %


In [19]:
#save the joined_model for resnet supervised encoder
torch.save(joined_model.state_dict(), './saved_models/joined_model_resnet_supervised.pth')
