<a href="https://colab.research.google.com/github/surriu111/cclab/blob/main/attention%2Bperception_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
# Mount to Google Drive & Switch to the dataset directory

from google.colab import drive
drive.mount('/content/drive')

import os

os.listdir('/content/drive/MyDrive/ml-com')

import torch
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

import torch
import torch.nn as nn
import torch.nn.functional as F



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


['types2label_128_pruned7.txt',
 'train_labels_128.npy',
 'demo.ipynb',
 'label2type.txt',
 'model.py',
 'train.npz',
 'Untitled folder']

In [17]:
import torch
import random
import numpy as np

# Set a fixed random seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [18]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, npz_path):
        npz_data = np.load(npz_path)
        self.images = npz_data["images"] # (N, 3, 128, 128) in np.uint8
        self.labels = npz_data["labels"] # (N,) in np.int64
        assert self.images.shape[0] == self.labels.shape[0]
        print(f"{npz_path}: images shape {self.images.shape}, "
              f"labels shape {self.labels.shape}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = torch.tensor(self.images[idx]) / 255 # convert to [0, 1] range
        label = torch.tensor(self.labels[idx])
        return image, label

npz_path = '/content/drive/MyDrive/ml-com/train.npz'
train_dataset = CustomDataset(npz_path)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)

/content/drive/MyDrive/ml-com/train.npz: images shape (18900, 3, 128, 128), labels shape (18900,)


In [4]:
# sample data batch
images, labels = next(iter(train_loader))
print(f"images shape: {images.shape}")
print(f"labels shape: {labels.shape}")

images shape: torch.Size([128, 3, 128, 128])
labels shape: torch.Size([128])


In [32]:
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.size()
        q = self.query(x).view(B, -1, H * W).permute(0, 2, 1)   # B x HW x C'
        k = self.key(x).view(B, -1, H * W)                      # B x C' x HW
        attn = torch.bmm(q, k)                                 # B x HW x HW
        attn = torch.softmax(attn, dim=-1)
        v = self.value(x).view(B, -1, H * W)                   # B x C x HW
        out = torch.bmm(v, attn.permute(0, 2, 1))              # B x C x HW
        out = out.view(B, C, H, W)
        return self.gamma * out + x


In [None]:
class Classifier(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(latent_dim * 32 * 32, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )

    def forward(self, z):
        return self.fc(z)


In [22]:
class ConvVAE(nn.Module):
    def __init__(self, input_channels=3, latent_channels=8):
        super().__init__()

        # --- Encoder ---
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1),  # 64 x 64 x 64
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),             # 128 x 32 x 32
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),            # 256 x 32 x 32
            nn.ReLU()
        )

        # --- Attention ---
        self.attn = SelfAttention(256)

        # --- Quantization conv ---
        self.quant_conv = nn.Conv2d(256, latent_channels * 2, kernel_size=1)    # 4+4 channels

        # --- Decoder ---
        self.post_quant_conv = nn.Conv2d(latent_channels, 256, kernel_size=1)

        self.decoder_attn = SelfAttention(256)

        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),            # 128 x 32 x 32
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),    # 64 x 64 x 64
            nn.ReLU(),
            nn.ConvTranspose2d(64, input_channels, kernel_size=4, stride=2, padding=1), # 3 x 128 x 128
            nn.Sigmoid() # Predict within value range [0, 1]
        )

    def preprocess(self, x):
        return x

    def vae_encode(self, x):
        x = self.preprocess(x)
        h = self.encoder(x)
        h = self.attn(h)
        h = self.quant_conv(h)
        mean, logvar = torch.chunk(h, 2, dim=1)
        if self.training:
            logvar = logvar.clamp(-30.0, 20.0)
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mean + eps * std
        else:
            z = mean
        return z, mean, logvar

    def encode(self, x):
        x = self.preprocess(x)
        h = self.encoder(x)
        h = self.attn(h)
        h = self.quant_conv(h)
        mean, logvar = torch.chunk(h, 2, dim=1)
        if self.training:
            logvar = logvar.clamp(-30.0, 20.0)
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mean + eps * std
        else:
            z = mean
        return z

    def decode(self, z):
      h = self.post_quant_conv(z)
      h = self.decoder_attn(h)
      x_recon = self.decoder(h)
      return x_recon

    def forward(self, x):
        z, mean, logvar = self.vae_encode(x)
        x_recon = self.decode(z)
        return x_recon, z, mean, logvar


In [None]:
class Classifier(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(latent_dim * 32 * 32, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )

    def forward(self, z):
        return self.fc(z)


In [38]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- VGG16 Feature Extractor (for perceptual loss) ---
vgg = models.vgg16(pretrained=True).features[:16].eval().to('cuda')
for p in vgg.parameters():
    p.requires_grad = False

def perceptual_loss_fn(x_recon, x_true):
    recon_feat = vgg(x_recon)
    true_feat = vgg(x_true)
    return F.mse_loss(recon_feat, true_feat)

# --- ConvVAE Model Definition ---
class ConvVAE(nn.Module):
    def __init__(self, input_channels=3, latent_channels=8):
        super().__init__()

        # --- Encoder ---
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1),  # 64 x 64 x 64
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),             # 128 x 32 x 32
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),            # 256 x 32 x 32
            nn.ReLU()
        )

        # --- Attention ---
        self.attn = SelfAttention(256)

        # --- Quantization conv ---
        self.quant_conv = nn.Conv2d(256, latent_channels * 2, kernel_size=1)    # 4+4 channels

        # --- Decoder ---
        self.post_quant_conv = nn.Conv2d(latent_channels, 256, kernel_size=1)

        self.decoder_attn = SelfAttention(256)

        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),            # 128 x 32 x 32
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),    # 64 x 64 x 64
            nn.ReLU(),
            nn.ConvTranspose2d(64, input_channels, kernel_size=4, stride=2, padding=1), # 3 x 128 x 128
            nn.Sigmoid() # Predict within value range [0, 1]
        )

    def preprocess(self, x):
        return x

    def vae_encode(self, x):
        x = self.preprocess(x)
        h = self.encoder(x)
        h = self.attn(h)
        h = self.quant_conv(h)
        mean, logvar = torch.chunk(h, 2, dim=1)
        if self.training:
            logvar = logvar.clamp(-30.0, 20.0)
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mean + eps * std
        else:
            z = mean
        return z, mean, logvar

    def encode(self, x):
        x = self.preprocess(x)
        h = self.encoder(x)
        h = self.attn(h)
        h = self.quant_conv(h)
        mean, logvar = torch.chunk(h, 2, dim=1)
        if self.training:
            logvar = logvar.clamp(-30.0, 20.0)
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mean + eps * std
        else:
            z = mean
        return z

    def decode(self, z):
      h = self.post_quant_conv(z)
      h = self.decoder_attn(h)
      x_recon = self.decoder(h)
      return x_recon

    def forward(self, x):
        z, mean, logvar = self.vae_encode(x)
        x_recon = self.decode(z)
        return x_recon, z, mean, logvar

# ----- Loss Function -----
def vae_loss(x, x_recon, mean, logvar, vgg, kl_weight=0.1, perceptual_weight=0.1):
    recon_loss = F.mse_loss(x, x_recon, reduction='mean')
    kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

    # 计算 perceptual loss
    perceptual_loss = perceptual_loss_fn(x_recon, x)

    # 总损失
    final_loss = recon_loss + kl_weight * kl_loss + perceptual_weight * perceptual_loss
    return final_loss, recon_loss, kl_loss, perceptual_loss

# ----- Training -----
def train_vae(model, dataloader, optimizer, device, vgg, num_epochs=1, perceptual_weight=0.1):
    model.train()
    for epoch in range(num_epochs):
        loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, _ in loop:
            x = images.to(device)
            optimizer.zero_grad()
            x_recon, z, mean, logvar = model(x)
            loss, recon_loss, kl_loss, perceptual_loss = vae_loss(x, x_recon, mean, logvar, vgg, perceptual_weight=perceptual_weight)
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item(), recon=recon_loss.item(), kl=kl_loss.item(), perceptual=perceptual_loss.item())
         # 更新学习率

# ---- Setup ----

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "T4 GPU")

# Model
model = ConvVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train
train_vae(model, train_loader, optimizer, device, vgg, num_epochs=15)

Epoch 1/15: 100%|██████████| 591/591 [02:00<00:00,  4.90it/s, kl=0.239, loss=0.183, perceptual=1.5, recon=0.00905]
Epoch 2/15: 100%|██████████| 591/591 [02:00<00:00,  4.89it/s, kl=0.249, loss=0.129, perceptual=0.978, recon=0.0059]
Epoch 3/15: 100%|██████████| 591/591 [02:00<00:00,  4.89it/s, kl=0.273, loss=0.124, perceptual=0.911, recon=0.00543]
Epoch 4/15: 100%|██████████| 591/591 [02:00<00:00,  4.89it/s, kl=0.316, loss=0.127, perceptual=0.89, recon=0.00629]
Epoch 5/15: 100%|██████████| 591/591 [02:00<00:00,  4.89it/s, kl=0.345, loss=0.129, perceptual=0.887, recon=0.00548]
Epoch 6/15: 100%|██████████| 591/591 [02:00<00:00,  4.89it/s, kl=0.33, loss=0.122, perceptual=0.834, recon=0.00583]
Epoch 7/15: 100%|██████████| 591/591 [02:00<00:00,  4.90it/s, kl=0.322, loss=0.0988, perceptual=0.624, recon=0.00425]
Epoch 8/15: 100%|██████████| 591/591 [02:00<00:00,  4.89it/s, kl=0.367, loss=0.103, perceptual=0.621, recon=0.00437]
Epoch 9/15: 100%|██████████| 591/591 [02:00<00:00,  4.90it/s, kl=0.4

In [24]:
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
# 在初始化外部写一次
import torchvision.models as models
vgg = models.vgg16(pretrained=True).features[:16].eval().to('T4 GPU')
for p in vgg.parameters():
    p.requires_grad = False

def perceptual_loss_fn(x_recon, x_true):
    recon_feat = vgg(x_recon)
    true_feat = vgg(x_true)
    return F.mse_loss(recon_feat, true_feat)
# ----- Loss Function -----
def vae_loss(x, x_recon, mean, logvar, kl_weight=0.1):
    recon_loss = F.mse_loss(x, x_recon, reduction='mean')
    kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
    final_loss = recon_loss + kl_weight * kl_loss
    return final_loss, recon_loss, kl_loss

# ----- Training -----
def train_vae(model, dataloader, optimizer, device, num_epochs=1):
    model.train()
    for epoch in range(num_epochs):
        loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        # loop = dataloader
        for images, labels in loop:
            x, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            x_recon, z, mean, logvar = model(x)
            loss, recon_loss, kl_loss = vae_loss(x, x_recon, mean, logvar)
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item(), recon=recon_loss.item(), kl=kl_loss.item())

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "T4 GPU")

# Model
model = ConvVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train
train_vae(model,train_loader, optimizer, device, num_epochs=20)

Epoch 1/20: 100%|██████████| 591/591 [01:00<00:00,  9.81it/s, kl=0.0604, loss=0.0182, recon=0.0121]
Epoch 2/20: 100%|██████████| 591/591 [01:01<00:00,  9.64it/s, kl=0.0561, loss=0.0149, recon=0.00928]
Epoch 3/20: 100%|██████████| 591/591 [01:01<00:00,  9.64it/s, kl=0.0664, loss=0.0186, recon=0.012]
Epoch 4/20: 100%|██████████| 591/591 [01:01<00:00,  9.66it/s, kl=0.0696, loss=0.0188, recon=0.0118]
Epoch 5/20: 100%|██████████| 591/591 [01:01<00:00,  9.64it/s, kl=0.0664, loss=0.0166, recon=0.00999]
Epoch 6/20: 100%|██████████| 591/591 [01:01<00:00,  9.65it/s, kl=0.0535, loss=0.0129, recon=0.00756]
Epoch 7/20: 100%|██████████| 591/591 [01:01<00:00,  9.66it/s, kl=0.0629, loss=0.0157, recon=0.00941]
Epoch 8/20: 100%|██████████| 591/591 [01:01<00:00,  9.65it/s, kl=0.065, loss=0.0159, recon=0.00939]
Epoch 9/20: 100%|██████████| 591/591 [01:01<00:00,  9.65it/s, kl=0.0682, loss=0.0175, recon=0.0107]
Epoch 10/20: 100%|██████████| 591/591 [01:01<00:00,  9.63it/s, kl=0.0651, loss=0.0153, recon=0.00

In [1]:
# ----- Visualization -----
def plot_reconstructions(model, dataloader, device, num_images=8):
    model.eval()
    with torch.no_grad():
        x = next(iter(dataloader))[0].to(device)
        x_recon, z, _, _ = model(x)
        x = x.cpu().numpy()
        x_recon = x_recon.cpu().numpy()
        print(f"Latent bottleneck dimension: {z.flatten(start_dim=1).shape[1]}")
        print(x_recon[0, :, 64, 64])

        plt.figure(figsize=(16, 4))
        for i in range(num_images):
            # Original
            plt.subplot(2, num_images, i+1)
            plt.imshow(x[i].transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
            plt.axis('off')

            # Reconstruction
            plt.subplot(2, num_images, i+1+num_images)
            plt.imshow(x_recon[i].transpose(1, 2, 0))
            plt.axis('off')

        plt.show()

plot_reconstructions(model, train_loader, device, num_images=8)

NameError: name 'model' is not defined

In [43]:
# Submission

# 1) Save model weights
torch.save(model.state_dict(), "checkpoint6.pt")

# 2) Prepare the 'Model' class for submission
with open("/content/drive/MyDrive/ml-com/model.py", "r") as f:
    print(f.read())

# 3) Submit the model code & weights online

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, input_channels=3, latent_channels=4):
        super().__init__()
        # Make sure the layers are consistent with your checkpoint weights
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        self.quant_conv = nn.Conv2d(256, latent_channels * 2, kernel_size=1)
        self.post_quant_conv = nn.Conv2d(latent_channels, 256, kernel_size=1)
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose

In [None]:
# Metric 1: Recon MSE on test set (on value range [0, 1])
# Metric 2: Classification accuracy (linear probing with test set latents, 170 classes)
# Final Score: recon_mse / probing_accuracy (the lower the better)

In [None]:

import importlib
import torch
import numpy as np

def load_model(model_path, weights_path):
    spec = importlib.util.spec_from_file_location("model_module", model_path)
    model_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model_module)  # Load the module
    model = model_module.Model()  # Create an instance of the model class
    model.load_state_dict(torch.load(weights_path))  # Load weights
    print("model loaded successfully")

    # try small data on cpu to check if the model is loaded correctly
    test_data = np.random.rand(3, 3, 128, 128)  # Example input data
    test_data = torch.tensor(test_data, dtype=torch.float32)
    test_data = test_data.to("cpu")
    model = model.to("cpu")
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        output = model.encode(test_data)
        output = model.decode(output)
    print("Model loaded successfully and output generated.")

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np

# model_e and model_d are two identical instances of your model class
def run_inference_AE(test_data_numpy, test_label_numpy, num_classes,
                  model_e, model_d, gpu_index,
                  batch_size=64, timeout=50, bottleNeckDim = 8192):
    device = torch.device(f"cuda:{gpu_index}")
    print(f"Using device: {device}")
    model_e.to(device)  # Move the model to the GPU
    model_e.eval()  # Set the model to evaluation mode
    model_d.to(device)  # Move the model to the GPU
    model_d.eval()  # Set the model to evaluation mode

    # build test dataloader from the numpy array
    test_data = torch.tensor(test_data_numpy, dtype=torch.float32)
    test_labels = torch.tensor(test_label_numpy, dtype=torch.long)
    test_dataset = torch.utils.data.TensorDataset(test_data, test_labels)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_latents = []
    criterion = nn.MSELoss(reduction='sum')
    reconstruction_loss = 0
    shape_checked = False
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(test_loader, start=1):
            images, labels = images.to(device), labels.to(device)  # Move data to the GPU
            latents = model_e.encode(images)

            print("latents shape:", latents.shape)
            # check latents shape not too large
            if not shape_checked:
                latents_orig_shape = latents.shape
                latents = latents.view(latents.shape[0], -1)
                if latents.shape[1] > bottleNeckDim:
                    raise ValueError(f"Latents shape is too large: {latents.shape}. Expected less than {bottleNeckDim}.")
                latents = latents.view(latents_orig_shape)
                shape_checked = True

            outputs = model_d.decode(latents)

            # compute reconstruction loss
            loss = criterion(outputs, images)
            reconstruction_loss += loss.item()

            all_latents.append(latents.cpu().numpy())

        reconstruction_loss = reconstruction_loss / len(test_loader.dataset)

        # sample images from the latent space
        # mean and std of all_latents
        all_latents = np.concatenate(all_latents, axis=0)
        mean_latents = np.mean(all_latents, axis=0)
        std_latents = np.std(all_latents, axis=0)

        # sample 5 random latents
        random_latents = np.random.normal(mean_latents, std_latents, (all_latents[:5].shape))
        # reconstruct the images from the latents
        random_latents = torch.tensor(random_latents, dtype=torch.float32).to(device)
        sampled_images = model_d.decode(random_latents)
        sampled_images = sampled_images.cpu().numpy()
        # save the reconstructed images, optional

    # release gpu memory
    torch.cuda.empty_cache()
