## 作业六——多模态模型
本次作业的目的是让同学们体验使用CLIP处理多模态任务，训练一个多模态少样本目标分类模型。

本次作业需要完成的内容：
1. 仔细阅读代码，补全TODO(1)~TODO(3)标记的内容

需要提交的内容
- 补全后的代码。
- 实验报告（一个PDF文件），要求记录实验准确率，并作简要分析总结。

**请注意，提交的.ipynb文件需要能够重启kernel后，通过一次 全部运行 复现出所有的实验结果。请同学们在提交前按这种方式进行一次检查。**
有随机性的部分复现结果存在一些差别是允许的，但请至少保证不要相差太远，保证代码可以正常执行。


In [2]:
import jittor as jt
from jittor import nn
from jittor.transform import CenterCrop, ImageNormalize, Compose, resize
from PIL import Image
import os
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
import numpy as np
from typing import Union, Tuple, List, Optional
import gzip
import html
import os
from functools import partial

from jittor import Var
from jittor.nn import Module, Linear, softmax, linear, matmul_transpose
from jittor.dataset import Dataset

import ftfy
import regex as re
import math

# Turn on Jittor's CUDA flag so that subsequent computation runs on the GPU.
jt.flags.use_cuda = 1

[38;5;2m[i 1229 13:26:32.296452 12 cuda_flags.cc:55] CUDA enabled.[m


In [4]:
# If you are using the vGPU environment these files are already downloaded and
# need not be fetched again; otherwise download them yourself with the commands
# below (BPE vocabulary, dataset archive, and the ViT-B/32 CLIP weights).
!wget -O bpe_simple_vocab_16e6.txt.gz "https://cloud.tsinghua.edu.cn/f/a65c412d818d4e46bd42/?dl=1"
!wget -O dataset.zip "https://cloud.tsinghua.edu.cn/f/1550782fa0f346028a82/?dl=1"
!unzip dataset.zip
!wget -O ViT-B-32.pkl "https://cloud.tsinghua.edu.cn/f/860a9afb257f4d5180be/?dl=1"

--2025-12-29 13:26:55--  https://cloud.tsinghua.edu.cn/f/a65c412d818d4e46bd42/?dl=1
Resolving cloud.tsinghua.edu.cn (cloud.tsinghua.edu.cn)... 101.6.15.69, 2402:f000:1:402:101:6:15:69
Connecting to cloud.tsinghua.edu.cn (cloud.tsinghua.edu.cn)|101.6.15.69|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cloud.tsinghua.edu.cn/seafhttp/files/16fa6b3f-7e3b-4b55-a2dc-c8faf99726d9/bpe_simple_vocab_16e6.txt.gz [following]
--2025-12-29 13:26:55--  https://cloud.tsinghua.edu.cn/seafhttp/files/16fa6b3f-7e3b-4b55-a2dc-c8faf99726d9/bpe_simple_vocab_16e6.txt.gz
Reusing existing connection to cloud.tsinghua.edu.cn:443.
HTTP request sent, awaiting response... 200 OK
Length: 1356917 (1.3M) [application/octet-stream]
Saving to: ‘bpe_simple_vocab_16e6.txt.gz’


2025-12-29 13:26:55 (6.41 MB/s) - ‘bpe_simple_vocab_16e6.txt.gz’ saved [1356917/1356917]

--2025-12-29 13:26:56--  https://cloud.tsinghua.edu.cn/f/1550782fa0f346028a82/?dl=1
Resolving cloud.tsinghua.edu.cn (

### 自然语言的切分词

在这个部分中，`tokenize`函数能够将`N`个字符串转换为一个大小为`[N, context_length]`的`Var`张量。

In [5]:
def bytes_to_unicode():
    """Build a reversible mapping from every byte value to one unicode character.

    Bytes inside the three printable ranges ("!".."~", "¡".."¬", "®".."ÿ")
    map to the character with the same code point.  Every other byte is
    mapped, in increasing byte order, to the code points 256, 257, ... so
    that all 256 bytes receive a distinct character.

    Returns:
        dict: ``{byte_value: character}`` with 256 entries; the printable
        bytes come first in insertion order.
    """
    printable_ranges = (("!", "~"), ("¡", "¬"), ("®", "ÿ"))
    byte_values = [
        code
        for first, last in printable_ranges
        for code in range(ord(first), ord(last) + 1)
    ]
    code_points = list(byte_values)
    printable = set(byte_values)

    shift = 0
    for byte in range(256):
        if byte in printable:
            continue
        # Non-printable byte: park it above the single-byte range.
        byte_values.append(byte)
        code_points.append(256 + shift)
        shift += 1

    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}

def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    Args:
        word: A sliceable sequence of symbols (a tuple of strings in the BPE
            code, but any sequence such as a plain string works).

    Returns:
        set: Every ``(word[i], word[i + 1])`` pair.  The set is empty when
        ``word`` has fewer than two symbols, including the empty sequence
        (which previously raised ``IndexError``).
    """
    # zip stops at the shorter argument, so short inputs naturally yield no pairs.
    return set(zip(word, word[1:]))

def basic_clean(text):
    """Repair the text with ftfy, decode HTML entities and trim the ends.

    Args:
        text: Raw input string.

    Returns:
        str: The cleaned string with leading/trailing whitespace removed.
    """
    repaired = ftfy.fix_text(text)
    # Unescape twice so doubly-escaped entities (e.g. "&amp;lt;") decode fully.
    decoded = html.unescape(html.unescape(repaired))
    return decoded.strip()

def whitespace_clean(text):
    """Collapse every run of whitespace into one space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()

class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzipped merge table.

    The vocabulary is laid out as: the 256 byte symbols, the same 256 symbols
    with an end-of-word marker ``</w>``, one entry per merge rule, and finally
    the two special tokens ``<|startoftext|>`` and ``<|endoftext|>``.
    """

    def __init__(self, bpe_path):
        """Load the merge rules and build the encoder/decoder tables.

        Args:
            bpe_path: Path to a gzip-compressed UTF-8 text file.  The first
                line is skipped; each following line holds one
                whitespace-separated merge pair.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Read inside a context manager so the file handle is always closed.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Drop the header line and keep at most 49152 - 256 - 2 merge rules.
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(self.byte_encoder.values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank means the merge is applied earlier.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens are pre-seeded so they are never split by `bpe`.
        self.cache = {
            '<|startoftext|>': '<|startoftext|>',
            '<|endoftext|>': '<|endoftext|>'
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)

    def bpe(self, token):
        """Apply the merge rules to one pre-tokenized, byte-encoded token.

        Args:
            token: Non-empty string of byte-encoded characters.

        Returns:
            str: The resulting BPE symbols joined by single spaces; the last
            symbol carries the ``</w>`` end-of-word marker.  Results are
            memoized in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>', )
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(
                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[
                        i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Convert a string into a list of vocabulary ids.

        The text is cleaned, lower-cased, split with ``self.pat``, mapped to
        byte-encoded characters, and each piece is run through `bpe`.

        Args:
            text: Input string.

        Returns:
            list: Integer token ids (no start/end tokens are added here).
        """
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b]
                            for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Convert a sequence of vocabulary ids back into text.

        Args:
            tokens: Iterable of integer token ids.

        Returns:
            str: Decoded text where each ``</w>`` marker becomes a space;
            undecodable byte sequences are replaced rather than raising.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text
                          ]).decode('utf-8',
                                    errors="replace").replace('</w>', ' ')
        return text

# Module-level tokenizer instance used by `tokenize`; it is built from the BPE
# merge file expected in the current working directory.
_tokenizer = SimpleTokenizer("bpe_simple_vocab_16e6.txt.gz")

def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False):
    """Turn one or more strings into a padded matrix of token ids.

    Args:
        texts: A single string or a list of ``N`` strings.
        context_length: Number of columns of the returned matrix.
        truncate: If True, over-long sequences are cut to ``context_length``
            and their last id is overwritten with the end-of-text token.

    Returns:
        jt.Var: int64 Var of shape ``[N, context_length]``; each row is
        start-of-text, the encoded text, end-of-text, then zero padding.

    Raises:
        RuntimeError: If a sequence exceeds ``context_length`` and
            ``truncate`` is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    vocabulary = _tokenizer.encoder
    sot_token = vocabulary["<|startoftext|>"]
    eot_token = vocabulary["<|endoftext|>"]

    token_rows = []
    for text in texts:
        token_rows.append([sot_token, *_tokenizer.encode(text), eot_token])

    result = jt.zeros((len(token_rows), context_length), dtype=jt.int64)

    for row, tokens in enumerate(token_rows):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(
                    f"Input {texts[row]} is too long for context length {context_length}"
                )
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        # Columns past len(tokens) keep their zero padding.
        result[row, :len(tokens)] = jt.Var(tokens)

    return result

### 注意力层

这个部分是简化实现的注意力层。

In [6]:
class MultiheadAttention(Module):
    """Simplified multi-head self-attention with a fused q/k/v projection.

    The input projection weight and bias are allocated with ``jt.empty`` and
    are therefore uninitialized here; presumably they are filled in when
    pretrained parameters are loaded into the model.
    """

    def __init__(self, embed_dim, num_heads) -> None:
        """
        Args:
            embed_dim: Size of the feature dimension.
            num_heads: Number of attention heads; must divide ``embed_dim``.
        """
        super().__init__()
        head_dim, remainder = divmod(embed_dim, num_heads)

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        assert remainder == 0, "embed_dim must be divisible by num_heads"

        self.in_proj_weight = jt.empty((3 * embed_dim, embed_dim), dtype=jt.float32)
        self.in_proj_bias = jt.empty(3 * embed_dim, dtype=jt.float32)

        self.out_proj = Linear(embed_dim, embed_dim)

    def execute(self, x: Var, attn_mask: Optional[Var] = None) -> Var:
        """Run self-attention.

        Args:
            x: Var of shape ``[tgt_len, bsz, embed_dim]``.
            attn_mask: Optional additive mask that is reshaped to
                ``[1, 1, tgt_len, tgt_len]`` and added to the attention
                scores before the softmax.

        Returns:
            Var of shape ``[tgt_len, bsz, embed_dim]``.
        """
        seq_len, batch, dim = x.shape
        heads = self.num_heads
        per_head = dim // heads
        assert dim == self.embed_dim

        # One matmul produces q, k and v stacked along the last axis.
        qkv = matmul_transpose(x, self.in_proj_weight) + self.in_proj_bias
        qkv = jt.reshape(qkv, (seq_len, batch, 3, dim))

        def split_heads(index):
            # [seq_len, batch, dim] -> [batch, heads, seq_len, per_head]
            part = qkv[:, :, index, :]
            return part.view(seq_len, batch, heads, per_head).permute(1, 2, 0, 3)

        q = split_heads(0)
        k = split_heads(1)
        v = split_heads(2)

        # Scores: [batch, heads, seq_len, seq_len]
        scaled_keys = k.transpose(-2, -1) / math.sqrt(per_head)
        scores = jt.matmul(q, scaled_keys)
        if attn_mask is not None:
            scores += jt.reshape(attn_mask, (1, 1, seq_len, seq_len))
        probs = softmax(scores, dim=-1)

        # Context: [batch, heads, seq_len, per_head]
        context = jt.matmul(probs, v)
        context = context.permute(2, 0, 1, 3).reshape(batch * seq_len, dim)

        projected = self.out_proj(context)
        return projected.reshape(seq_len, batch, dim)

class QuickGELU(nn.Module):
    """Sigmoid-gated activation computing ``x * sigmoid(1.702 * x)``."""

    def execute(self, x):
        gate = jt.sigmoid(1.702 * x)
        return x * gate

class MLP(nn.Module):
    """Two-layer feed-forward block: expand to 4x width, QuickGELU, project back."""

    def __init__(self, d_model):
        """
        Args:
            d_model: Input and output feature size.
        """
        super().__init__()
        hidden_dim = d_model * 4
        self.c_fc = nn.Linear(d_model, hidden_dim)
        self.gelu = QuickGELU()
        self.c_proj = nn.Linear(hidden_dim, d_model)

    def execute(self, x):
        hidden = self.c_fc(x)
        activated = self.gelu(hidden)
        return self.c_proj(activated)

class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual add."""

    def __init__(self, d_model, n_head, attn_mask):
        """
        Args:
            d_model: Feature size of the block.
            n_head: Number of attention heads.
            attn_mask: Additive attention mask, or None for no masking.
        """
        super().__init__()

        self.attn = MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model)
        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        """Apply self-attention to ``x`` using the stored mask."""
        mask = self.attn_mask
        # The mask is added to the scores, so its dtype has to match the input.
        assert mask is None or mask.dtype == x.dtype
        return self.attn(x, attn_mask=mask)

    def execute(self, x):
        attended = self.attention(self.ln_1(x))
        x = x + attended
        transformed = self.mlp(self.ln_2(x))
        return x + transformed


### Vision Transformer

原始的Vision Transformer的结构如下图所示：
![](imgs/ViT.png)

实际实现中后续的MLP Head被替换为了一个简单的投影。

在这个部分中，你需要实现TODO(1)，完成图像切块、添加额外的`Class Embedding`以及添加`Position Embedding`。


In [7]:
class Transformer(nn.Module):
    """A stack of ``layers`` residual attention blocks applied in sequence."""

    def __init__(self, width, layers, heads, attn_mask=None):
        """
        Args:
            width: Feature size of every block.
            layers: Number of blocks in the stack.
            heads: Number of attention heads per block.
            attn_mask: Optional additive attention mask shared by all blocks.
        """
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def execute(self, x):
        output = self.resblocks(x)
        return output

class VisionTransformer(nn.Module):
    """ViT image encoder: patchify, prepend a class token, run a transformer,
    then project the class token to ``output_dim``."""

    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int):
        """
        Args:
            input_resolution: Side length of the (square) input image.
            patch_size: Side length of each square patch.
            width: Feature size of the transformer.
            layers: Number of transformer blocks.
            heads: Number of attention heads per block.
            output_dim: Size of the returned image feature.
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # A strided convolution cuts the image into non-overlapping patches
        # and embeds each one in a single step.
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        init_scale = width**-0.5
        num_patches = (input_resolution // patch_size)**2
        self.class_embedding = init_scale * jt.randn(width)
        self.positional_embedding = init_scale * jt.randn(
            (num_patches + 1, width))
        self.ln_pre = nn.LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = nn.LayerNorm(width)
        self.proj = init_scale * jt.randn((width, output_dim))

    def execute(self, x:jt.Var):
        """VisionTransformer forward

        Args:
            x (jt.Var): Input image with shape [B, C, H, W] (B, 3, 224, 224)

        Returns:
            jt.Var: Output image feature with shape [B, output_dim]
        """
        # TODO(1): Convert the image into a sequence.
        # Your code starts here

        # step1: patch embedding -> [B, width, grid, grid],
        # where grid == H // patch_size == W // patch_size
        patches = self.conv1(x)

        # step2: flatten the spatial grid and move it in front of the
        # channels -> [B, grid ** 2, width]
        batch, channels, _, _ = patches.shape
        tokens = patches.reshape(batch, channels, -1).permute(0, 2, 1)

        # step3: prepend the class embedding (shape [width,]) to every
        # sequence -> [B, 1 + grid ** 2, width]
        cls_token = self.class_embedding.unsqueeze(0).unsqueeze(0).expand(batch, 1, -1)
        tokens = jt.concat([cls_token, tokens], dim=1)

        # step4: add the positional_embedding (shape [grid ** 2 + 1, width]),
        # broadcast over the batch axis
        tokens = tokens + self.positional_embedding.unsqueeze(0)

        # Your code ends here

        tokens = self.ln_pre(tokens)

        tokens = tokens.permute(1, 0, 2)  # NLD -> LND
        tokens = self.transformer(tokens)
        tokens = tokens.permute(1, 0, 2)  # LND -> NLD

        # Only the class token (position 0) is used as the image summary.
        cls_output = self.ln_post(tokens[:, 0, :])
        return cls_output @ self.proj

# Smoke test: push one random tensor shaped like a 224x224 RGB image through a
# freshly constructed (randomly initialized) VisionTransformer.
img = jt.randn(1, 3, 224, 224)
vit = VisionTransformer(224, 32, 768, 12, 768 // 64, 512)
# Expected printed shape: [1, 512]
print(vit(img).shape)


[1,512,]


### CLIP模型

这个部分为CLIP模型的实现，它能够将文本和图像映射为特征向量。

CLIP模型的预训练需要大量的资源，本次作业中会直接加载预训练好的CLIP模型。

In [8]:
class CLIP(nn.Module):
    """CLIP model with a ViT image encoder and a transformer text encoder.

    Several tensors (text positional embedding, text projection) are created
    with ``jt.empty`` and so hold no meaningful values until parameters are
    loaded into the model.
    """

    def __init__(
            self,
            embed_dim: int,
            # vision
            image_resolution: int,
            vision_layers: Union[Tuple[int, int, int, int], int],
            vision_width: int,
            vision_patch_size: int,
            # text
            context_length: int,
            vocab_size: int,
            transformer_width: int,
            transformer_heads: int,
            transformer_layers: int):
        """
        Args:
            embed_dim: Size of the shared image/text feature space.
            image_resolution: Side length of the input images.
            vision_layers: Number of blocks in the image transformer.
            vision_width: Feature size of the image transformer.
            vision_patch_size: Side length of each image patch.
            context_length: Number of token positions per text.
            vocab_size: Number of rows in the token embedding table.
            transformer_width: Feature size of the text transformer.
            transformer_heads: Attention heads in the text transformer.
            transformer_layers: Number of blocks in the text transformer.
        """
        super().__init__()

        # Must be set first: build_attention_mask() below reads it.
        self.context_length = context_length

        # The image tower uses one attention head per 64 channels.
        self.visual = VisionTransformer(input_resolution=image_resolution,
                                        patch_size=vision_patch_size,
                                        width=vision_width,
                                        layers=vision_layers,
                                        heads=vision_width // 64,
                                        output_dim=embed_dim)

        text_mask = self.build_attention_mask()
        self.transformer = Transformer(width=transformer_width,
                                       layers=transformer_layers,
                                       heads=transformer_heads,
                                       attn_mask=text_mask)

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = jt.empty(
            (self.context_length, transformer_width))
        self.ln_final = nn.LayerNorm(transformer_width)

        self.text_projection = jt.empty((transformer_width, embed_dim))
        self.logit_scale = jt.ones([]) * np.log(1 / 0.07)

    def build_attention_mask(self):
        """Return a ``[context_length, context_length]`` additive mask.

        Entries strictly above the diagonal are ``-inf`` and all others are
        zero, so a position cannot attend to later positions.
        """
        size = self.context_length
        full = jt.empty((size, size))
        full.fill_(float("-inf"))
        # triu with offset 1 zeroes the diagonal and everything below it.
        return jt.triu(full, 1)

    @property
    def dtype(self):
        """dtype of the image patch-embedding convolution weight."""
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Encode a batch of images with the vision transformer."""
        return self.visual(image)

    def encode_text(self, text):
        """Encode a batch of tokenized texts.

        Args:
            text: Token id matrix of shape ``[batch_size, n_ctx]``.

        Returns:
            One projected feature vector per text, taken at the position of
            the largest token id in each row.
        """
        hidden = self.token_embedding(text)
        hidden = hidden + self.positional_embedding

        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden)

        # hidden.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        row_index = jt.arange(hidden.shape[0])
        # argmax yields a pair here; element 0 is used as the index tensor.
        eot_index = text.argmax(dim=-1)[0]
        eot_features = hidden[row_index, eot_index]
        return eot_features @ self.text_projection

def build_model(state_dict: dict):
    """Construct the ViT-B/32 CLIP model and load its parameters.

    Args:
        state_dict: Mapping from parameter name to value.  The entries
            "input_resolution", "context_length" and "vocab_size" are ignored
            (presumably checkpoint metadata rather than weights).  The
            mapping itself is left unmodified.

    Returns:
        CLIP: The model in eval mode with the parameters loaded.
    """
    model = CLIP(
        embed_dim=512,
        image_resolution=224, vision_layers=12, vision_width=768, vision_patch_size=32,
        context_length=77, vocab_size=49408,
        transformer_width=512, transformer_heads=8, transformer_layers=12
    )

    # Filter into a new dict instead of deleting keys in place, so the
    # caller's state_dict is not mutated as a side effect.
    skipped_keys = ("input_resolution", "context_length", "vocab_size")
    parameters = {
        key: value
        for key, value in state_dict.items()
        if key not in skipped_keys
    }

    model.load_parameters(parameters)
    model.eval()

    return model

### 模型加载与数据预处理

In [9]:
def _resize(img:Image.Image, size, mode):
    """Resize ``img`` so its shorter side equals ``size``, keeping aspect ratio.

    Args:
        img: Input PIL image.
        size: Target length of the shorter side.
        mode: Interpolation mode forwarded to ``resize``.

    Returns:
        The input image itself when its shorter side already equals ``size``,
        otherwise the result of ``resize``.
    """
    width, height = img.size
    portrait = width <= height

    short, long = (width, height) if portrait else (height, width)
    if short == size:
        return img

    scaled_long = int(size * long / short)
    if portrait:
        new_w, new_h = size, scaled_long
    else:
        new_w, new_h = scaled_long, size
    # The target size is passed as (height, width).
    return resize(img, (new_h, new_w), mode)

def _to_tensor(data):
    """Convert array-like ``data`` to a ``jt.Var`` with at least three axes.

    Inputs with fewer than three dimensions get one trailing axis appended.
    """
    array = np.asarray(data)
    if array.ndim < 3:
        array = array[..., np.newaxis]
    return jt.Var(array)

def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode."""
    rgb_image = image.convert("RGB")
    return rgb_image

def _transform(n_px):
    """Build the image preprocessing pipeline for an ``n_px`` x ``n_px`` input.

    The steps are: bicubic resize of the shorter side to ``n_px``, center
    crop, RGB conversion, per-channel normalization, conversion to a Var.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        partial(_resize, size=n_px, mode=Image.BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ImageNormalize(channel_mean, channel_std),
        _to_tensor,
    ]
    return Compose(steps)

def load(name):
    """Load a CLIP checkpoint and build its matching image preprocessing.

    Args:
        name: Path of the checkpoint file passed to ``jt.load``.

    Returns:
        tuple: ``(model, transform)`` where ``transform`` targets the model's
        visual input resolution.
    """
    model = build_model(jt.load(name))
    transform = _transform(model.visual.input_resolution)
    return model, transform

# Load the CLIP weights from the local checkpoint file; `clip_model` and
# `preprocess` are used as globals by the cells below.
clip_model, preprocess = load("ViT-B-32.pkl")



Compiling Operators(1/1) used: 2.45s eta:    0s 


In [10]:
class CLIPDataset(Dataset):
    """Image-folder dataset whose sub-directories are named ``0, 1, 2, ...``.

    Each sub-directory of ``root`` is one class, and the directory name is
    the integer class label.  Only image files whose lower-cased file stem is
    in ``allowed_names`` are kept, which is how the same folder tree is
    split into training and validation subsets.
    """

    def __init__(self, root, transform, allowed_names:set, batch_size=16, shuffle=False, **kwargs):
        """
        Args:
            root: Directory containing one numerically named folder per class.
            transform: Callable applied to each loaded RGB image, or a falsy
                value to return the PIL image unchanged.
            allowed_names: Lower-cased file stems to include.
            batch_size: Forwarded to the base ``Dataset``.
            shuffle: Forwarded to the base ``Dataset``.
            **kwargs: Extra options forwarded to the base ``Dataset``.
        """
        super().__init__(batch_size, shuffle, **kwargs)
        self.root = root
        self.transform = transform
        # Sort numerically (not lexicographically) so that "10" follows "9".
        with os.scandir(root) as entries:
            self.classes = sorted([int(d.name) for d in entries if d.is_dir()])
        self.classes = [str(n) for n in self.classes]
        self.class_to_idx = {v:k for k,v in enumerate(self.classes)}
        self.imgs = []
        image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

        for i, class_name in enumerate(self.classes):
            class_dir = os.path.join(root, class_name)
            # Class folders must be numbered contiguously from 0.
            assert class_name == str(i), "{} != {}".format(class_name, str(i))
            for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)):
                for fname in sorted(fnames):
                    stem, ext = os.path.splitext(fname)
                    if ext.lower() not in image_exts:
                        continue
                    if stem.lower() not in allowed_names:
                        continue
                    # Join with the directory os.walk actually found the file
                    # in, so images in nested folders get a valid path.
                    path = os.path.join(dname, fname)
                    self.imgs.append((path, i))
        self.set_attrs(total_len=len(self.imgs))

    def __getitem__(self, k):
        """Return ``(image, label)`` for the k-th sample."""
        path, label = self.imgs[k]
        with open(path, 'rb') as f:
            img = Image.open(f).convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, label

# Files whose stem is 1-4 form the training split; stems 5-14 form the
# validation split.
train_names = {str(i) for i in range(1, 5)}
val_names = {str(i) for i in range(5, 15)}
train_dataset = CLIPDataset(
    'Dataset/Images/',
    transform=preprocess,
    allowed_names=train_names,
    batch_size=16,
    shuffle=True,
)
val_dataset = CLIPDataset(
    'Dataset/Images/',
    transform=preprocess,
    allowed_names=val_names,
    batch_size=16,
    shuffle=False,
)


### 图像特征生成与评估
本次作业中，我们使用TOP1准确率和TOP5准确率来评估方法进行图像分类的能力。

其中，TOP1准确率为 分类正确图像数量/图片总数，TOP5准确率为 预测概率最高的5个类别中有正确类别的图像数量/图片总数。


In [11]:
def evaluate(predictions:np.ndarray, labels:np.ndarray, return_result=False):
    """Print (and optionally return) Top-1 and Top-5 accuracy.

    Args:
        predictions: Score matrix of shape ``[N, n_classes]``; a larger score
            means a more likely class.
        labels: Integer class labels of shape ``[N]``.
        return_result: When True, return the accuracies in addition to
            printing them.

    Returns:
        ``(top_1_acc, top_5_acc)`` as fractions in ``[0, 1]`` when
        ``return_result`` is True, otherwise None.
    """
    assert predictions.shape[0] == labels.shape[0]

    # Class indices ordered from best to worst score, cut to the best five.
    best_five = np.argsort(predictions, axis=-1)[:, ::-1][:, :5]
    best_one = best_five[:, :1]
    expected = labels[:, np.newaxis]

    hit_top5 = (best_five == expected).any(axis=-1)
    hit_top1 = (best_one == expected).any(axis=-1)
    top_5_acc = hit_top5.sum() / hit_top5.shape[0]
    top_1_acc = hit_top1.sum() / hit_top1.shape[0]

    print("Top1 Acc: {:.3f} %, Top5 Acc: {:.3f} %".format(top_1_acc * 100, top_5_acc * 100))
    if return_result:
        return top_1_acc, top_5_acc

def get_image_features(dataset, model:CLIP):
    """Encode every batch of ``dataset`` with the model's image encoder.

    Args:
        dataset: Iterable yielding ``(images, labels)`` batches.
        model: CLIP model providing ``encode_image``.

    Returns:
        tuple: ``(features, labels)`` as NumPy arrays, concatenated over all
        batches in iteration order.
    """
    feature_batches = []
    label_batches = []
    for images, targets in tqdm(dataset, total=len(dataset)):
        feature_batches.append(model.encode_image(images))
        label_batches.append(targets)
    all_features = jt.concat(feature_batches).numpy()
    all_labels = jt.concat(label_batches).numpy()
    return all_features, all_labels

def normalize(features):
    """Scale each vector along the last axis to unit L2 norm.

    Args:
        features: NumPy array whose last axis holds the feature vectors.

    Returns:
        Array of the same shape with every vector divided by its L2 norm.
    """
    lengths = np.linalg.norm(features, ord=2, axis=-1, keepdims=True)
    return features / lengths

# Encode both splits once with the CLIP image encoder; the resulting NumPy
# arrays are reused by the linear probe, zero-shot and adapter experiments.
train_features, train_labels = get_image_features(train_dataset, clip_model)
val_features, val_labels = get_image_features(val_dataset, clip_model)

 78%|███████▊  | 73/94 [00:07<00:02,  9.41it/s]
Compiling Operators(3/42) used: 3.31s eta: 43.1s 8/42) used: 4.32s eta: 18.4s 10/42) used: 5.33s eta: 17.1s 11/42) used: 6.34s eta: 17.9s 16/42) used: 7.34s eta: 11.9s 19/42) used: 9.35s eta: 11.3s 21/42) used: 10.4s eta: 10.4s 23/42) used: 11.4s eta: 9.39s 27/42) used: 12.4s eta: 6.89s 28/42) used: 13.4s eta: 6.71s 30/42) used: 14.4s eta: 5.77s 33/42) used: 15.4s eta: 4.21s 37/42) used: 16.4s eta: 2.22s 38/42) used: 17.4s eta: 1.84s 40/42) used: 18.4s eta: 0.922s 41/42) used: 19.5s eta: 0.474s 42/42) used: 22.5s eta:    0s 
100%|██████████| 94/94 [00:33<00:00,  2.78it/s]

Compiling Operators(2/3) used: 3.31s eta: 1.66s 3/3) used: 4.32s eta:    0s 

Compiling Operators(1/1) used: 2.48s eta:    0s 
100%|██████████| 234/234 [00:26<00:00,  8.96it/s]


### Linear Probe
Linear Probe方法不使用文本特征，而是直接利用逻辑回归对图像特征进行分类。


In [12]:
# Linear probe: fit a logistic-regression classifier on the L2-normalized
# training image features (no text features involved) and score the
# validation features with its predicted class probabilities.
classifier = LogisticRegression(random_state=0,
                                C=8.960,
                                max_iter=1000,
                                verbose=1)
classifier.fit(normalize(train_features), train_labels)
# predict_proba gives one probability per class, which `evaluate` ranks.
predictions = classifier.predict_proba(normalize(val_features))
evaluate(predictions, val_labels)


Top1 Acc: 64.492 %, Top5 Acc: 86.471 %


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    3.4s finished


### Zero-shot分类
通过比较类别文本特征和图片特征的相似度，模型无需进行额外训练，就可以预测图片属于哪一份文本对应的类别。

在这个部分中，你需要实现`TODO(2)`，通过比较文本特征和图片特征进行图像分类。

具体来说，文本特征和图片特征之间的比较是通过计算它们的余弦相似度实现的。余弦相似度的计算公式为：
$$
cosine\_similarity(A, B) = \frac{A\cdot B}{|A|\cdot |B|}
$$
其中$A$和$B$是形状相同的一维向量，分子中的$\cdot$代表向量点积，分母中的$\cdot$代表标量乘法，$|A|$表示向量$A$的L2范数。

In [13]:
# Build one text prompt per class from Dataset/classes.txt.  Each line has the
# form "<name> <index>"; the index must equal the line's position in the file.
# Read through a context manager so the file handle is closed promptly.
with open('Dataset/classes.txt') as classes_file:
    classes = classes_file.read().splitlines()
new_classes = []
basenames = ['Animal', 'Thu-dog', 'Caltech-101', 'Food-101']
for c in classes:
    # Prompts are appended in file order, so position i must be class i.
    assert len(new_classes) == int(c.split(' ')[1])
    c = c.split(' ')[0]
    for bname in basenames:
        if c.startswith(bname):
            # Strip the dataset prefix plus the one separator character after it.
            c = c[len(bname)+1:]
    c = 'a photo of ' + c
    new_classes.append(c)

# TODO(2): Perform image classification by directly comparing image features with text features.
# Your code starts here
# step1: tokenize the new_classes (one prompt per class)
text_tokens = tokenize(new_classes)

# step2: encode the text
# `text_features` stays a Jittor Var because the adapter cell below reuses it;
# `normalize` works on NumPy arrays, hence the .numpy() conversion.
text_features = clip_model.encode_text(text_tokens)
text_features_norm = normalize(text_features.numpy())

# step3: generate prediction results by comparing features.
# Note: predictions should be a Numpy array with shape [N_imgs, N_classes],
# where predictions[i, j] is the cosine similarity between the image i and the classification j.
# Both sides are scaled to unit L2 norm, so the dot product equals the cosine similarity.
val_features_norm = normalize(val_features)
predictions = val_features_norm @ text_features_norm.T  

# Your code ends here
evaluate(predictions, val_labels)



Compiling Operators(3/11) used: 3.31s eta: 8.83s 8/11) used: 4.32s eta: 1.62s 10/11) used: 5.32s eta: 0.532s 11/11) used: 6.33s eta:    0s 


Top1 Acc: 67.888 %, Top5 Acc: 88.102 %


### 使用Adapter进行微调
为了能够同时使用文本特征以及训练数据，我们添加一个参数量很小的可训练的Adapter，对图像特征进行微调。

在这个部分中，你需要实现`TODO(3)`。

相较于Zero-shot部分，其不同之处在于：
1. 需要使用Adapter处理图像特征
2. 需要乘上`logit_scale`用于训练。


In [14]:
class AdapterDataset(Dataset):
    """Dataset over precomputed image features and their labels."""

    def __init__(self, features, labels, batch_size=16, shuffle=True, **kwargs):
        """
        Args:
            features: Array of image features; its first axis is the sample axis.
            labels: Array of labels aligned with ``features``.
            batch_size: Forwarded to the base ``Dataset``.
            shuffle: Forwarded to the base ``Dataset``.
            **kwargs: Extra options forwarded to the base ``Dataset``.
        """
        super().__init__(batch_size, shuffle, **kwargs)
        self.total_len = features.shape[0]
        self.features = jt.array(features)
        self.labels = jt.array(labels)

    def __getitem__(self, index):
        """Return the ``(feature, label)`` pair stored at ``index``."""
        feature = self.features[index]
        label = self.labels[index]
        return feature, label

class Adapter(Module):
    """Small bottleneck MLP blended with its input through a fixed ratio.

    The output is ``alpha * feature + (1 - alpha) * fc(feature)``, so
    ``alpha`` is the weight given to the original (unadapted) feature.
    """

    def __init__(self, in_channel, hidden_channel, alpha):
        """
        Args:
            in_channel: Size of the input (and output) feature.
            hidden_channel: Size of the bottleneck layer.
            alpha: Weight of the original feature in the blend.
        """
        # Initialize the base Module, as every other Module in this file does.
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_channel, hidden_channel),
            nn.ReLU(),
            nn.Linear(hidden_channel, in_channel),
            nn.ReLU(),
        )
        self.alpha = alpha

    def execute(self, feature):
        """Return the blend of ``feature`` and its adapted version."""
        return feature * self.alpha + self.fc(feature) * (1 - self.alpha)

class CLIPAdapter(Module):
    """Classify adapted image features against fixed text features.

    Image features pass through a trainable `Adapter`; logits are the scaled
    cosine similarities between adapted image features and text features.
    """

    def __init__(self, in_channel, hidden_channel, logit_scale, alpha=0.2):
        """
        Args:
            in_channel: Size of the image/text features.
            hidden_channel: Bottleneck size of the image adapter.
            logit_scale: Var holding the log of the similarity scale; its
                exponential is stored and excluded from gradient computation.
            alpha: Weight of the original image feature in the adapter blend.
        """
        # Initialize the base Module, as every other Module in this file does.
        super().__init__()
        self.image_adapter = Adapter(in_channel, hidden_channel, alpha=alpha)
        self.logit_scale = logit_scale.exp()
        self.logit_scale.stop_grad()

    def execute(self, image_feature, text_feature):
        """Compute classification logits.

        Args:
            image_feature: Var of shape ``[N_imgs, in_channel]``.
            text_feature: Var of shape ``[N_classes, in_channel]``.

        Returns:
            Var of shape ``[N_imgs, N_classes]`` holding the scaled cosine
            similarities.
        """
        # TODO(3): Execution of CLIP Adapter
        # Your code starts here
        # step1: process the image_feature with image_adapter
        adapted_image_feature = self.image_adapter(image_feature)

        # step2: normalize the image_feature and the text_feature
        image_feature_norm = adapted_image_feature / jt.norm(adapted_image_feature, dim=1, keepdim=True)
        text_feature_norm = text_feature / jt.norm(text_feature, dim=1, keepdim=True)

        # step3: Compute the similarity between `image_feature` and `text_feature`
        # using matrix multiplication, and scale the result by the factor `logit_scale`
        logits = self.logit_scale * jt.matmul(image_feature_norm, text_feature_norm.transpose(1, 0))

        # Your code ends here
        return logits

def train_adapter(model:CLIP, train_features, train_labels, text_features, n_epochs=5):
    """Train a fresh `CLIPAdapter` on precomputed image features.

    Args:
        model: CLIP model; only its ``logit_scale`` is used.
        train_features: Training image features.
        train_labels: Integer class labels aligned with ``train_features``.
        text_features: One text feature per class.
        n_epochs: Number of passes over the training data.

    Returns:
        CLIPAdapter: The trained adapter (still in train mode).
    """
    loader = AdapterDataset(train_features, train_labels, shuffle=True)
    adapter = CLIPAdapter(512, 128, model.logit_scale)
    adapter.train()
    # Only the image adapter's parameters are optimized.
    optimizer = jt.optim.Adam(adapter.image_adapter.parameters(), lr=0.0003)
    class_text_features = jt.array(text_features)

    for _ in range(n_epochs):
        for batch_features, batch_labels in loader:
            logits = adapter(batch_features, class_text_features)
            loss = nn.cross_entropy_loss(logits, batch_labels)
            loss.sync()
            optimizer.step(loss)
    return adapter

def run_adapter():
    """Train one adapter on the global training features and evaluate it.

    Uses the notebook globals ``clip_model``, ``train_features``,
    ``train_labels``, ``val_features``, ``val_labels`` and ``text_features``.

    Returns:
        Top-1 accuracy on the validation features, as a fraction.
    """
    adapter = train_adapter(clip_model, train_features, train_labels, text_features)
    adapter.eval()
    with jt.no_grad():
        logits = adapter(jt.array(val_features), jt.array(text_features))
        scores = logits.numpy()
    top1, _ = evaluate(scores, val_labels, return_result=True)
    return top1

# Adapter training is stochastic, so repeat it ten times and report the mean
# Top-1 accuracy over the runs.
baselines = [run_adapter() for _ in range(10)]
mean_top1 = sum(baselines) / len(baselines)
print("Mean Top1: {:.3f} %".format(mean_top1 * 100))


Compiling Operators(4/16) used: 3.31s eta: 9.94s 10/16) used: 4.32s eta: 2.59s 14/16) used: 7.33s eta: 1.05s 16/16) used: 8.34s eta:    0s 

Compiling Operators(1/16) used: 3.31s eta: 49.7s 9/16) used: 4.32s eta: 3.36s 10/16) used: 5.33s eta:  3.2s 11/16) used: 7.34s eta: 3.33s 14/16) used: 8.34s eta: 1.19s 15/16) used: 9.35s eta: 0.623s 16/16) used: 11.4s eta:    0s 

Compiling Operators(3/3) used: 4.31s eta:    0s 


Top1 Acc: 70.722 %, Top5 Acc: 91.765 %
Top1 Acc: 71.283 %, Top5 Acc: 91.765 %
Top1 Acc: 71.283 %, Top5 Acc: 92.246 %
Top1 Acc: 71.604 %, Top5 Acc: 92.059 %
Top1 Acc: 71.390 %, Top5 Acc: 91.711 %
Top1 Acc: 71.738 %, Top5 Acc: 92.273 %
Top1 Acc: 71.096 %, Top5 Acc: 92.594 %
Top1 Acc: 71.444 %, Top5 Acc: 92.112 %
Top1 Acc: 71.230 %, Top5 Acc: 91.791 %
Top1 Acc: 71.444 %, Top5 Acc: 91.845 %
Mean Top1: 71.324 %
