In [33]:
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchinfo import summary

from datasets import HousingDataset
from utils import Select, CustomScale

In [34]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Running on device: {device}")

Running on device: cpu


In [35]:
# Per-channel statistics of the raw 4-channel input.
# NOTE(review): provenance of these numbers is not shown here — presumably
# dataset-wide band mean/std; record where they came from.
mean = np.array([1377, 1354, 1381, 2356])
std = np.array([540, 398, 327, 515])

# Channels kept after normalization. Both the forward Select and the reverse
# transform use this one list, so changing it keeps the inverse consistent.
CHANNELS = [0, 1, 2]

# Per-channel upper bound mean + 3*std; SCALE has shape (4, 1, 1) so it can
# broadcast against channel-first images.
upperbound = mean + 3 * std
SCALE = upperbound[:, np.newaxis, np.newaxis]

# The statistics expressed in the rescaled units used by Normalize.
norm_mean = mean / upperbound
norm_std = std / upperbound

transformations = [
    transforms.CenterCrop(size=(32, 32)),
    # presumably multiplies by `scale` and clamps to [0, 1] — TODO confirm CustomScale semantics
    CustomScale(scale=1/SCALE, clamp=(0, 1.0)),
    transforms.Normalize(mean=norm_mean, std=norm_std),
    Select(dim=-3, indices=CHANNELS),
]
transform = transforms.Compose(transformations)

# Approximate inverse of `transform` for the selected channels (the crop and
# the clamp cannot be undone). Normalize computes (x - mean) / std, so
# mean=-m/s, std=1/s yields x*s + m; scaling by SCALE then undoes 1/SCALE.
reverse_transform = transforms.Compose([
    transforms.Normalize(
        mean=-norm_mean[CHANNELS] / norm_std[CHANNELS],
        std=1 / norm_std[CHANNELS],
    ),
    CustomScale(scale=SCALE[CHANNELS], clamp=None),
])

In [36]:
# Dataset with the preprocessing pipeline attached; presumably `transform` is
# applied to each sample's images on access — TODO confirm in HousingDataset.
train_set = HousingDataset(
    transform=transform,
)

In [37]:
# Shuffled mini-batches of 4. A seeded generator makes the shuffle order, and
# therefore the batch drawn from this loader, reproducible on Restart & Run All.
LOADER_SEED = 0
train_loader = DataLoader(
    train_set,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

In [38]:
# Take only the first batch from the loader for inspection / smoke testing.
loader_iter = iter(train_loader)
batch = next(loader_iter)
del loader_iter

In [23]:
# ImageNet-pretrained ResNet-18 used as a feature extractor. layer4, avgpool
# and fc are replaced by Identity, so the forward pass returns the flattened
# layer3 feature map (the model summary reports 1024 features for 32x32 input).
# `pretrained=True` is deprecated in favour of the `weights=` enum
# (IMAGENET1K_V1 is the checkpoint `pretrained=True` loads); fall back to the
# old flag on torchvision releases that predate the enum.
if hasattr(torchvision.models, "ResNet18_Weights"):
    encoder = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
    )
else:
    encoder = torchvision.models.resnet18(pretrained=True)
encoder = encoder.to(device=device)
encoder.layer4 = torch.nn.Identity()
encoder.avgpool = torch.nn.Identity()
encoder.fc = torch.nn.Identity()

In [24]:
# Layer-by-layer output shapes and parameter counts of the truncated encoder
# for a batch of four 3x32x32 images.
encoder_input_size = (4, 3, 32, 32)
summary(encoder, input_size=encoder_input_size)

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   --                        --
├─Conv2d: 1-1                            [4, 64, 16, 16]           9,408
├─BatchNorm2d: 1-2                       [4, 64, 16, 16]           128
├─ReLU: 1-3                              [4, 64, 16, 16]           --
├─MaxPool2d: 1-4                         [4, 64, 8, 8]             --
├─Sequential: 1-5                        [4, 64, 8, 8]             --
│    └─BasicBlock: 2-1                   [4, 64, 8, 8]             --
│    │    └─Conv2d: 3-1                  [4, 64, 8, 8]             36,864
│    │    └─BatchNorm2d: 3-2             [4, 64, 8, 8]             128
│    │    └─ReLU: 3-3                    [4, 64, 8, 8]             --
│    │    └─Conv2d: 3-4                  [4, 64, 8, 8]             36,864
│    │    └─BatchNorm2d: 3-5             [4, 64, 8, 8]             128
│    │    └─ReLU: 3-6                    [4, 64, 8, 8]             --
│

In [25]:
class Model(torch.nn.Module):
    """Scores a (start, end, sample) triplet of images with one shared encoder.

    The three image batches go through the same ``encoder``. Each encoding is
    flattened and passed to its own projection head (Linear + ReLU). The three
    projections are concatenated, and a small MLP maps them to one sigmoid
    output per example.
    """

    def __init__(self, encoder, input_shape=(1,3,32,32), proj_features=256):
        """
        Input:
            encoder (torch.nn.Module) - image encoder shared by all three inputs.
            input_shape (tuple) - shape of a dummy batch used once to infer the
                encoder's flattened output size (only input_shape[1:] matters).
            proj_features (int) - output width of each per-image projection head.
        """
        super().__init__()
        self.encoder = encoder

        in_features = self._infer_feature_size(encoder, input_shape)
        self.start_proj = self._projection(in_features, proj_features)
        self.end_proj = self._projection(in_features, proj_features)
        self.sample_proj = self._projection(in_features, proj_features)
        self.clf = torch.nn.Sequential(
            torch.nn.Linear(3 * proj_features, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1),
        )

    @staticmethod
    def _projection(in_features, out_features):
        """Build one per-image projection head: Linear followed by ReLU."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, out_features),
            torch.nn.ReLU(),
        )

    @staticmethod
    def _infer_feature_size(encoder, input_shape):
        """Return the number of features ``encoder`` produces per image.

        Runs one zero-filled dummy batch through the encoder. The pass runs
        without autograd and with every submodule in eval mode, so BatchNorm
        running statistics (e.g. of a pretrained backbone) are not overwritten
        by the dummy input. Each submodule's own train/eval flag is restored
        afterwards. The dummy batch is allocated on the device of the
        encoder's parameters, so this also works once the encoder is on a GPU.
        """
        param = next(encoder.parameters(), None)
        dummy_input = torch.zeros(
            input_shape, device=None if param is None else param.device
        )

        saved_modes = [(module, module.training) for module in encoder.modules()]
        encoder.eval()
        try:
            with torch.no_grad():
                output_shape = encoder(dummy_input).shape
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        # Plain Python int rather than a numpy scalar for torch.nn.Linear.
        return int(np.prod(output_shape[1:]))

    def forward(self, img_start, img_end, img_sample):
        """
        Input:
            img_start, img_end, img_sample (FloatTensor) - three image batches
                of identical shape with batch size B on dim 0; presumably
                (B, C, H, W) matching input_shape[1:] — TODO confirm.
        Output:
            FloatTensor of shape (B, 1) holding sigmoid outputs in [0, 1].
        """
        B = img_start.shape[0]
        # One encoder call for all three batches. Rows are ordered
        # [start..., end..., sample...], so reshaping to (3, B, -1) splits the
        # result back into one flattened feature matrix per input.
        images = torch.cat([img_start, img_end, img_sample], dim=0)
        x_start, x_end, x_sample = self.encoder(images).reshape(3, B, -1)
        x_start = self.start_proj(x_start)
        x_end = self.end_proj(x_end)
        x_sample = self.sample_proj(x_sample)
        x = torch.cat([x_start, x_end, x_sample], dim=-1)
        x = self.clf(x)
        x = torch.sigmoid(x)
        return x

In [29]:
# The projection heads and classifier are created on the default (CPU) device.
# Move the whole model to `device` so they sit next to the encoder.
model = Model(encoder).to(device)

In [31]:
# One (4, 3, 32, 32) input size for each of the three forward() arguments.
summary(model, [(4, 3, 32, 32)] * 3)

Layer (type:depth-idx)                        Output Shape              Param #
Model                                         --                        --
├─ResNet: 1-1                                 [12, 1024]                --
│    └─Conv2d: 2-1                            [12, 64, 16, 16]          9,408
│    └─BatchNorm2d: 2-2                       [12, 64, 16, 16]          128
│    └─ReLU: 2-3                              [12, 64, 16, 16]          --
│    └─MaxPool2d: 2-4                         [12, 64, 8, 8]            --
│    └─Sequential: 2-5                        [12, 64, 8, 8]            --
│    │    └─BasicBlock: 3-1                   [12, 64, 8, 8]            73,984
│    │    └─BasicBlock: 3-2                   [12, 64, 8, 8]            73,984
│    └─Sequential: 2-6                        [12, 128, 4, 4]           --
│    │    └─BasicBlock: 3-3                   [12, 128, 4, 4]           230,144
│    │    └─BasicBlock: 3-4                   [12, 128, 4, 4]           295,42

In [32]:
# Smoke test on one batch. The image tensors are moved to `device` so the call
# also works when the model runs on a GPU (presumably the DataLoader yields
# CPU tensors — the loader does no device placement itself).
model(
    batch["image_start"].to(device),
    batch["image_end"].to(device),
    batch["image_sample"].to(device),
)

tensor([[0.4980],
        [0.4915],
        [0.4846],
        [0.4956]], grad_fn=<SigmoidBackward0>)