In [15]:
import torch
import sys
import yaml
from torchvision import transforms, datasets
import torchvision
import numpy as np
from torch.utils.data.dataloader import DataLoader

In [2]:
sys.path.append('../')
from models.model import ResNet18

In [3]:
batch_size = 2

In [4]:
config = yaml.load(open("../config/config.yaml", "r"), Loader=yaml.FullLoader)

In [5]:
data_transforms = torchvision.transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.cifar.CIFAR10(root='../data', train=True, download=False,
                                       transform=data_transforms)

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          num_workers=0, drop_last=False, shuffle=True)

test_dataset = datasets.cifar.CIFAR10(root='../data', train=False, download=False,
                                       transform=data_transforms)

test_loader = DataLoader(test_dataset, batch_size=batch_size,
                          num_workers=0, drop_last=False, shuffle=True)

In [6]:
device = 'cpu' #'cuda' if torch.cuda.is_available() else 'cpu'
online_network = ResNet18(**config['network'])

output_feature_dim = online_network.projetion.net[0].in_features

# remove the projection head
encoder = torch.nn.Sequential(*list(online_network.children())[:-1])

In [7]:
# load pre-trained parameters
load_params = torch.load(os.path.join('/model80000.pth'),
                         map_location=torch.device(torch.device(device)))

# if 'online_network_state_dict' in load_params:
#     online_network.load_state_dict(load_params['online_network_state_dict'])
#     print("'online_network_state_dict' parameters successfully loaded.")
online_network = online_network.to(device)

In [8]:
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        return self.linear(x)

In [9]:
logreg = LogisticRegression(output_feature_dim, 10)
logreg = logreg.to(device)

In [10]:
def get_features_from_encoder(encoder, loader):
    
    x_train = []
    y_train = []

    # get the features from the pre-trained model
    for i, (x, y) in enumerate(loader):
        with torch.no_grad():
            feature_vector = encoder(x).squeeze(3).squeeze(2)
            x_train.append(feature_vector)
            y_train.extend(y.numpy())
            break
            
    x_train = torch.cat(x_train, dim=0)
    y_train = torch.tensor(y_train)
    return x_train, y_train

In [11]:
encoder.eval()
x_train, y_train = get_features_from_encoder(encoder, train_loader)
print("Training data shape:", x_train.shape, y_train.shape)

x_test, y_test = get_features_from_encoder(encoder, test_loader)
print("Testing data shape:", x_test.shape, y_test.shape)

Training data shape: torch.Size([2, 512]) torch.Size([2])
Testing data shape: torch.Size([2, 512]) torch.Size([2])


In [12]:
x_train

tensor([[0.0000, 0.0236, 0.1258,  ..., 0.1410, 0.0195, 0.0094],
        [0.0000, 0.0520, 0.0985,  ..., 0.1189, 0.0240, 0.0099]])

In [13]:
def next_batch(X, y, batch_size):
    for i in range(0, X.shape[0], batch_size):
        X_batch = X[i: i+batch_size] / 255.
        y_batch = y[i: i+batch_size]
        yield X_batch.to(device), y_batch.to(device)

In [27]:
def accuracy(predictions, targets):
    return (predictions == targets).type(torch.float).mean()

In [30]:
optimizer = torch.optim.Adam(logreg.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1):
    for x, y in next_batch(x_train, y_train, batch_size=64):
        x = x.to(device)
        y = y.to(device)
        
        logits = logreg(x)
        predictions = torch.argmax(logits, dim=1)
     
        loss = criterion(logits, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    accuracy_list = []
    for x, y in next_batch(x_test, y_test, batch_size=256):
        x = x.to(device)
        y = y.to(device)
        
        logits = logreg(x)
        predictions = torch.argmax(logits, dim=1)
        
        acc = accuracy(predictions, y)
        accuracy_list.append(acc)
        
    print(f"Testing accuracy: {np.mean(accuracy_list)}")

torch.Size([2, 512])
torch.Size([2])
tensor([[-0.0052,  0.0198,  0.0069, -0.0428,  0.0235,  0.0391,  0.0353, -0.0032,
          0.0304, -0.0142],
        [-0.0051,  0.0200,  0.0074, -0.0426,  0.0239,  0.0390,  0.0357, -0.0031,
          0.0305, -0.0142]], grad_fn=<AddmmBackward>)
Testing accuracy: 0.5
