In [None]:
import os

# Print a sorted listing of /kaggle/working/ for a quick look at its contents.
try:
    arquivos = sorted(os.listdir('/kaggle/working/'))
    print(arquivos)
except FileNotFoundError:
    print("The directory /kaggle/working/ was not found.")
except Exception as err:
    print(f"An error occurred: {err}")

## simple_tokenizer.py

In [None]:
! pip install ftfy

In [None]:
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re



@lru_cache()
def default_bpe():
    """Return the absolute path of the BPE merges file (`bpe_simple_vocab_16e6.txt`).

    The location is fixed to the Kaggle input mount of the cafe-repo dataset.
    """
    bpe_file = '/kaggle/input/cafe-repo/Generalizable-FER-main/code/clip/bpe_simple_vocab_16e6.txt'
    return bpe_file


@lru_cache()
def bytes_to_unicode():
    """
    Build a reversible mapping from every byte value (0-255) to a printable unicode character.

    Byte-level BPE works on unicode strings, so each raw utf-8 byte needs a
    stand-in character. Bytes that are already printable ("!".."~", "¡".."¬",
    "®".."ÿ") map to themselves. The remaining bytes (whitespace/control
    characters the BPE code cannot handle) are assigned, in increasing byte
    order, to code points 256, 257, ... . This keeps the base vocabulary at
    256 symbols while never producing UNK tokens.

    Returns:
        dict[int, str]: byte value -> one-character string. The printable bytes
        come first in insertion order, followed by the remapped ones.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    printable_set = set(printable)
    byte_values = list(printable)
    code_points = list(printable)
    offset = 0
    for byte in range(256):
        if byte in printable_set:
            continue
        byte_values.append(byte)
        code_points.append(256 + offset)
        offset += 1
    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}


def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: a sliceable sequence of symbols (a tuple of variable-length
            strings in this file).

    Returns:
        set of ``(left, right)`` tuples, one for each pair of neighbouring
        symbols. A word with fewer than two symbols, including an empty one,
        has no pairs, so an empty set is returned instead of raising IndexError.
    """
    return set(zip(word, word[1:]))


def basic_clean(text):
    """Normalise text with ftfy.fix_text, undo two levels of HTML escaping, and trim surrounding whitespace."""
    cleaned = ftfy.fix_text(text)
    # Unescape twice so double-escaped entities such as "&amp;amp;" decode fully.
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()


def whitespace_clean(text):
    """Collapse each run of whitespace into a single space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()


class SimpleTokenizer(object):
    """Byte-level BPE tokenizer used by CLIP.

    Text is cleaned (ftfy, HTML unescaping, whitespace collapsing), lower-cased
    and split into pieces with a regex. The utf-8 bytes of each piece are mapped
    to printable unicode characters, and the result is merged with BPE.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        """Load the BPE merges and build the vocabulary.

        Args:
            bpe_path: path to the merges text file. The default is evaluated
                once, when the class is defined. The first line of the file is
                skipped and the next 49152-256-2 lines are used as merges. With
                a full file, the vocabulary therefore has 256 byte symbols, 256
                byte symbols with an ``</w>`` suffix, 48894 merges and 2 special
                tokens (49408 entries).
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Use a context manager so the file handle is closed deterministically.
        with open(bpe_path, 'rb') as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank = earlier line in the merges file; bpe() merges the lowest rank first.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Memoises bpe() results per token (unbounded). Special tokens map to themselves.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        # Special tokens, common contractions, letter runs, single digits, runs of other non-space characters.
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply BPE merges to one pre-tokenized piece.

        The last character gets an ``</w>`` end-of-word suffix. The adjacent pair
        with the lowest merge rank is then merged repeatedly until no known pair
        remains.

        Returns:
            str: the resulting symbols joined by single spaces. Multi-symbol
            results are cached per token.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the remaining symbols and stop.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Tokenize `text` into a list of BPE token ids (no start/end-of-text tokens are added)."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Convert token ids back to text.

        Each ``</w>`` marker becomes a space, and invalid utf-8 byte sequences
        are replaced.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text


## model.py

In [None]:
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convs) with anti-aliased downsampling.

    Every convolution has stride 1. When ``stride > 1``, the main path is
    spatially reduced by an ``AvgPool2d(stride)`` placed after the 3x3 conv.
    Whenever the stride or channel count changes, the shortcut uses its own
    average pool followed by a 1x1 projection.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        # Submodules are registered in the same order as upstream CLIP so that
        # state_dict keys and random initialisation order are unchanged.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride > 1:
            self.avgpool = nn.AvgPool2d(stride)
        else:
            self.avgpool = nn.Identity()

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        needs_projection = stride > 1 or inplanes != planes * Bottleneck.expansion
        if needs_projection:
            # Keys "-1", "0", "1" match the pretrained checkpoint layout.
            shortcut_layers = OrderedDict()
            shortcut_layers["-1"] = nn.AvgPool2d(stride)
            shortcut_layers["0"] = nn.Conv2d(inplanes, out_planes, 1, stride=1, bias=False)
            shortcut_layers["1"] = nn.BatchNorm2d(out_planes)
            self.downsample = nn.Sequential(shortcut_layers)

    def forward(self, x: torch.Tensor):
        out = self.conv1(x)
        out = self.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.relu(self.bn2(out))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class AttentionPool2d(nn.Module):
    """Pool a feature map with a single multi-head attention step.

    The ``H*W`` spatial positions become tokens. Their mean is prepended as an
    extra token, a learned positional embedding is added, and the attention
    output at the mean token is returned.

    Args:
        spacial_dim: side length of the (square) input feature map.
        embed_dim: channel count of the input.
        num_heads: number of attention heads.
        output_dim: output width (defaults to ``embed_dim``).
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        n_tokens = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(torch.randn(n_tokens, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        tokens = x.reshape(n, c, h * w).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        attended, _ = F.multi_head_attention_forward(
            query=tokens, key=tokens, value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # Row 0 is the output at the prepended mean token.
        return attended[0]


class ModifiedResNet(nn.Module):
    """ResNet variant used as CLIP's image encoder.

    Differences from torchvision's ResNet:
    - a stem of three 3x3 convolutions followed by average pooling (instead of
      one 7x7 conv and max pooling);
    - anti-aliased downsampling: strided stages average-pool before their
      stride-1 convolutions (see Bottleneck);
    - an attention-based pooling head (AttentionPool2d) instead of global
      average pooling.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        half_width = width // 2
        # Stem: stride-2 conv, two stride-1 convs, then 2x2 average pooling.
        self.conv1 = nn.Conv2d(3, half_width, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(half_width)
        self.conv2 = nn.Conv2d(half_width, half_width, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(half_width)
        self.conv3 = nn.Conv2d(half_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # Running input-channel count that _make_layer reads and advances.
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        # layer4 emits (width * 8) * Bottleneck.expansion = width * 32 channels.
        embed_dim = width * 32
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) first Bottleneck followed by `blocks - 1` stride-1 blocks."""
        first_block = Bottleneck(self._inplanes, planes, stride)
        self._inplanes = planes * Bottleneck.expansion
        remaining = [Bottleneck(self._inplanes, planes) for _ in range(blocks - 1)]
        return nn.Sequential(first_block, *remaining)

    def forward(self, x):
        x = x.type(self.conv1.weight.dtype)
        stem_layers = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, bn in stem_layers:
            x = self.relu(bn(conv(x)))
        x = self.avgpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)


class LayerNorm(nn.LayerNorm):
    """nn.LayerNorm that normalises in float32 and casts the result back to the input dtype (e.g. fp16)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.to(input_dtype)


class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate


class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``.

    Args:
        d_model: token width.
        n_head: number of attention heads.
        attn_mask: optional additive attention mask passed to nn.MultiheadAttention.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        hidden_width = d_model * 4

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden_width)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(hidden_width, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # The mask is a plain attribute, not a registered buffer, so it does not
        # follow .to()/.cuda(). Convert it to x's dtype/device on each call and keep the converted copy.
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))


class Transformer(nn.Module):
    """A stack of `layers` ResidualAttentionBlocks that all receive the same `attn_mask`.

    Callers in this file pass sequence-first (LND) tensors, matching
    nn.MultiheadAttention's default layout.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    """ViT image encoder.

    Images are split into patches with a strided convolution and a class token
    is prepended. The sequence goes through the transformer, and the
    LayerNorm-ed class token is projected to `output_dim`.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        n_patches = (input_resolution // patch_size) ** 2
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(n_patches + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        patches = self.conv1(x)  # [*, width, grid, grid]
        batch, channels = patches.shape[0], patches.shape[1]
        tokens = patches.reshape(batch, channels, -1).permute(0, 2, 1)  # [*, grid ** 2, width]

        # Broadcast the class embedding to one token per sample.
        cls_token = self.class_embedding.to(tokens.dtype) + torch.zeros(
            batch, 1, tokens.shape[-1], dtype=tokens.dtype, device=tokens.device
        )
        tokens = torch.cat([cls_token, tokens], dim=1)  # [*, grid ** 2 + 1, width]
        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        tokens = self.ln_pre(tokens)

        # NLD -> LND for the transformer, then back to NLD.
        tokens = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        cls_out = self.ln_post(tokens[:, 0, :])
        if self.proj is not None:
            cls_out = cls_out @ self.proj
        return cls_out


class CLIP(nn.Module):
    """Contrastive image-text model.

    Pairs a vision encoder with a causal text Transformer, and projects both
    into a shared `embed_dim` space. A tuple/list `vision_layers` (four stage
    depths) selects ModifiedResNet; an int selects VisionTransformer.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_width * 32 // 64,
                input_resolution=image_resolution,
                width=vision_width,
            )
        else:
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_width // 64,
                output_dim=embed_dim,
            )

        causal_mask = self.build_attention_mask()
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=causal_mask,
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable temperature, stored as a log and initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialise embeddings, the ResNet attention pool / bn3 scales and the text transformer with scaled normals."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            pool = self.visual.attnpool
            if pool is not None:
                pool_std = pool.c_proj.in_features ** -0.5
                for projection in (pool.q_proj, pool.k_proj, pool.v_proj, pool.c_proj):
                    nn.init.normal_(projection.weight, std=pool_std)

            # Zero the scale of each bottleneck's final BatchNorm (bn3).
            stages = (self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4)
            for stage in stages:
                for param_name, param in stage.named_parameters():
                    if param_name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        width = self.transformer.width
        proj_std = (width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = width ** -0.5
        fc_std = (2 * width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=width ** -0.5)

    def build_attention_mask(self):
        """Return an additive causal mask of shape (context_length, context_length).

        Entries strictly above the diagonal are -inf, so a token cannot attend
        to later tokens. The diagonal and everything below it are 0.
        """
        size = self.context_length
        return torch.empty(size, size).fill_(float("-inf")).triu_(1)

    @property
    def dtype(self):
        """Floating-point dtype of the model weights (taken from the visual stem conv)."""
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Encode images with the visual backbone after casting them to the model dtype."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Encode token ids of shape [batch_size, n_ctx] into projected text features.

        Each row's feature is taken at its highest token id. That is the
        end-of-text token, which has the largest id in the tokenizer vocabulary.
        """
        dtype = self.dtype
        x = self.token_embedding(text).type(dtype) + self.positional_embedding.type(dtype)
        x = self.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)  # NLD -> LND -> NLD
        x = self.ln_final(x).type(dtype)

        eot_positions = text.argmax(dim=-1)
        return x[torch.arange(x.shape[0]), eot_positions] @ self.text_projection

    def forward(self, image, text):
        """Return (logits_per_image, logits_per_text): scaled cosine similarities, each the transpose of the other."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        return logits_per_image, logits_per_image.t()


def convert_weights(model: nn.Module):
    """Cast applicable parameters of `model` to fp16, in place.

    The following are converted:
    - weight and bias of Conv1d, Conv2d and Linear layers;
    - the in/q/k/v projection weights, in_proj_bias, bias_k and bias_v of
      nn.MultiheadAttention;
    - any `text_projection` or `proj` attribute.

    Everything else (e.g. LayerNorm, Embedding) keeps its dtype.
    """
    mha_attrs = (
        "in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
        "in_proj_bias", "bias_k", "bias_v",
    )

    def _half_if_present(tensor):
        if tensor is not None:
            tensor.data = tensor.data.half()

    def _convert_weights_to_fp16(layer):
        if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            layer.weight.data = layer.weight.data.half()
            _half_if_present(layer.bias)

        if isinstance(layer, nn.MultiheadAttention):
            for attr_name in mha_attrs:
                _half_if_present(getattr(layer, attr_name))

        for attr_name in ("text_projection", "proj"):
            _half_if_present(getattr(layer, attr_name, None))

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    """Infer the CLIP architecture from a checkpoint state_dict, build it and load the weights.

    A checkpoint containing "visual.proj" is treated as a ViT; otherwise a
    ModifiedResNet is assumed. The keys "input_resolution", "context_length"
    and "vocab_size" are removed from `state_dict` in place before loading.
    The model's weights are converted with convert_weights, and the model is
    returned in eval mode.
    """
    if "visual.proj" in state_dict:
        patch_conv = state_dict["visual.conv1.weight"]
        vision_width = patch_conv.shape[0]
        vision_layers = sum(
            1 for key in state_dict.keys()
            if key.startswith("visual.") and key.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = patch_conv.shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        vision_layers = tuple(
            len({key.split(".")[2] for key in state_dict if key.startswith(f"visual.layer{stage}")})
            for stage in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        pool_tokens = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((pool_tokens - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == pool_tokens
        # The ResNet reduces spatial size by 32 before the attention pool.
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len({key.split(".")[2] for key in state_dict if key.startswith("transformer.resblocks")})

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
    )

    # Drop metadata entries that are not model parameters.
    for meta_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(meta_key, None)

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()


## clip.py

In [None]:
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm



# Prefer torchvision's InterpolationMode enum for bicubic resizing. If this
# torchvision build does not provide it (ImportError), fall back to PIL's
# BICUBIC constant.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

In [None]:
def parse_version_tuple(version: str, width: int = 3):
    """Parse the leading numeric components of a version string into a tuple of ints.

    A local build tag after '+' is dropped, and at most `width` dot-separated
    components are used. Each component contributes its leading ASCII digits,
    or 0 if it has none. Examples: "1.10.0" -> (1, 10, 0),
    "1.7.1+cu110" -> (1, 7, 1), "2.1.0a0+git1234" -> (2, 1, 0).
    """
    numbers = []
    for piece in version.split("+", 1)[0].split(".")[:width]:
        digits = ""
        for ch in piece:
            if not ("0" <= ch <= "9"):
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    return tuple(numbers)


# Compare numerically. Comparing lists of strings would rank "10" below "7" and
# wrongly warn on PyTorch 1.10-1.13.
if parse_version_tuple(torch.__version__) < (1, 7, 1):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")

In [None]:
def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the SHA256 hex digest of the file at `path`, reading it in chunks and closing it afterwards."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str):
    """Download `url` into directory `root` and verify its SHA256 checksum.

    The expected SHA256 hex digest is the second-to-last path segment of `url`.
    If a regular file with a matching digest already exists at the target
    path, it is returned without downloading. On a mismatch, a warning is
    issued and the file is downloaded again.

    Returns:
        str: path of the downloaded (or cached) file.

    Raises:
        RuntimeError: if the target path exists but is not a regular file, or
            if the downloaded file's checksum does not match.
    """
    # `import urllib` alone does not load the `urllib.request` submodule.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        content_length = source.info().get("Content-Length")
        # Servers may omit Content-Length; tqdm then shows an open-ended progress bar.
        total = int(content_length) if content_length else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target

In [None]:
def _convert_image_to_rgb(image):
    """Return the result of ``image.convert("RGB")`` (e.g. for grayscale or RGBA PIL images)."""
    rgb_image = image.convert("RGB")
    return rgb_image

In [None]:
def _transform(n_px):
    """Build the CLIP image preprocessing pipeline for `n_px` x `n_px` model inputs.

    The steps are:
    1. Bicubic resize with an int size (torchvision scales the shorter side to `n_px`).
    2. Center crop to `n_px` x `n_px`.
    3. Conversion to RGB.
    4. ToTensor.
    5. Per-channel normalisation with the hard-coded CLIP mean/std.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(clip_mean, clip_std),
    ]
    return Compose(steps)


In [None]:
def available_models() -> List[str]:
    """Return the names of the CLIP checkpoints registered in `_MODELS`, in insertion order."""
    return [model_name for model_name in _MODELS]

In [None]:
def patch_device(module, device_node=None):
    """Rewrite CUDA device constants in a TorchScript module's graphs.

    Every ``prim::Constant`` node in ``module.graph`` (and in
    ``module.forward1.graph`` when present) whose ``value`` starts with "cuda"
    gets its attributes copied from `device_node`. Objects whose ``graph``
    attribute raises RuntimeError, or that have no graph, are skipped.

    Args:
        module: a scripted module or scripted method.
        device_node: the ``prim::Constant`` node holding the target device.
            The original top-level version read a `device_node` that existed
            only inside `load`, so it raised NameError. The node is therefore
            passed explicitly.

    Raises:
        ValueError: if a CUDA constant needs patching but no `device_node` was given.
    """
    try:
        graphs = [module.graph] if hasattr(module, "graph") else []
    except RuntimeError:
        graphs = []

    if hasattr(module, "forward1"):
        graphs.append(module.forward1.graph)

    for graph in graphs:
        for node in graph.findAllNodes("prim::Constant"):
            if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                if device_node is None:
                    raise ValueError("patch_device needs a device_node to replace CUDA device constants")
                node.copyAttributes(device_node)

In [None]:
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model. The default is chosen once, when this function is defined.

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input

    Raises
    ------
    RuntimeError
        If `name` is neither a known model name nor an existing file.
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names: trace a tiny graph to obtain a Constant node holding the target device
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    # Defined here as a closure so it can see `device_node`. A module-level
    # helper cannot see this local and raises NameError.
    def _patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(_patch_device)
    _patch_device(model.encode_image)
    _patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        # NOTE(review): 5 is presumably the fp16 ScalarType code — confirm against torch's ScalarType enum.
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Tokenize one or more strings into a fixed-length tensor of token ids.

    Each row is ``<|startoftext|>`` + BPE ids + ``<|endoftext|>``, right-padded with 0.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        Row length of the output; all CLIP models use 77

    truncate: bool
        If True, over-long encodings are cut to `context_length` with the last
        position forced to the end-of-text token; if False they raise RuntimeError

    Returns
    -------
    A two-dimensional torch.long tensor of shape [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[sot_token, *_tokenizer.encode(text), eot_token] for text in texts]
    result = torch.zeros(len(encoded), context_length, dtype=torch.long)

    for row, tokens in enumerate(encoded):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[row]} is too long for context length {context_length}")
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[row, :len(tokens)] = torch.tensor(tokens)

    return result


In [None]:
# Module-level tokenizer used by `tokenize`. Constructing it reads the BPE merges
# file at SimpleTokenizer's default path (default_bpe()), so that file must exist.
_tokenizer = SimpleTokenizer()

In [None]:
# Names exported by `from <module> import *` if this code is used as a module.
__all__ = ["available_models", "load", "tokenize"]

# Checkpoint name -> download URL. The second-to-last URL path segment is the
# SHA256 digest that `_download` checks. `load` and `available_models` look this
# dict up at call time, so this cell must run before they are called.
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}

## Main code

In [None]:
import sys
# Add the repository's clip/ directory to the import path so that modules stored there can be imported.
sys.path.append('/kaggle/input/cafe-repo/Generalizable-FER-main/code/clip')

In [None]:
import os
import cv2
import csv
import math
import random
import numpy as np
import pandas as pd
import argparse
import pickle
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
import torch.utils.data as data
import torch.nn.functional as F
from torch.autograd import Variable
import pickle
from torch.autograd import Variable
import torch.utils.data as data
import pandas as pd
import random
from torchvision import transforms

In [None]:
# NOTE(review): CUDA_VISIBLE_DEVICES is normally read when CUDA is first
# initialised. `load`'s default `device` argument already called
# torch.cuda.is_available() when that cell ran, so this assignment may have no
# effect here. It probably needs to be set before torch touches CUDA — TODO confirm.
# If it does take effect on a single-GPU machine, index "1" hides the only GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
CLIP_BACKBONE = "ViT-B/32"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
clip_model, preprocess = load(CLIP_BACKBONE, device=device)

In [None]:
def build_arg_parser():
    """Return the ArgumentParser holding the dataset paths and training hyper-parameters."""
    # (flag, type, default, help) for every option, in registration order.
    specs = [
        ('--raf_path', str, '../../data/raf-basic', 'raf_dataset_path'),
        ('--resnet50_path', str, '../../data/resnet50_ft_weight.pkl', 'pretrained_backbone_path'),
        ('--label_path', str, 'list_patition_label.txt', 'label_path'),
        ('--workers', int, 2, 'number of workers'),
        ('--batch_size', int, 32, 'batch_size'),
        ('--w', int, 7, 'width of the attention map'),
        ('--h', int, 7, 'height of the attention map'),
        ('--gpu', int, 0, 'the number of the device'),
        ('--lam', float, 5, 'kl_lambda'),
        ('--epochs', int, 60, 'number of epochs'),
    ]
    arg_parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in specs:
        arg_parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)
    return arg_parser


parser = build_arg_parser()
# args=[] stops argparse from reading the Jupyter kernel's own sys.argv, so every option keeps its default.
args = parser.parse_args(args=[])

### Arquitetura do Modelo

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 (size-preserving at stride 1)."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

In [None]:
class BasicBlock(nn.Module):
    """Two 3x3-conv ResNet block with an optional 1x1 projection shortcut.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution and of the projection shortcut.
        downsample: if truthy, the identity path becomes 1x1 conv + BatchNorm
            so that it matches the main path's shape.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super().__init__()

        # Attribute names and creation order are kept fixed: they define the
        # state_dict keys matched against the pretrained checkpoint.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if downsample
            else None
        )

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)

In [None]:
class ResNet(nn.Module):
    """ResNet backbone assembled from ``block`` stages (e.g. ResNet-18 with BasicBlock, [2, 2, 2, 2]).

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        n_blocks: number of blocks in each of the 4 stages.
        channels: base width of each of the 4 stages.
        output_dim: size of the final linear layer.

    ``forward`` returns ``(logits, pooled_features)``.
    """

    def __init__(self, block, n_blocks, channels, output_dim):
        super().__init__()

        self.in_channels = channels[0]

        assert len(n_blocks) == len(channels) == 4

        # Registration order matters: Model slices list(children()) and expects
        # avgpool and fc to be the last two children.
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride = 2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride = 2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride = 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(self.in_channels, output_dim)

    def get_resnet_layer(self, block=BasicBlock, n_blocks=2, channels=64, stride=1):
        """Build one stage of ``n_blocks`` blocks with base width ``channels``.

        ``n_blocks`` and ``channels`` are per-stage ints. Only the first block
        uses ``stride`` and, when the channel count changes, a projection
        shortcut. ``self.in_channels`` is advanced to the stage's output width
        so that consecutive calls chain correctly.
        """
        out_channels = block.expansion * channels
        downsample = self.in_channels != out_channels

        layers = [block(self.in_channels, channels, stride, downsample)]
        layers.extend(block(out_channels, channels) for _ in range(1, n_blocks))

        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)  # pooled features, (B, in_channels)
        x = self.fc(h)

        return x, h

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the batch axis: (B, *) -> (B, prod(*))."""

    def forward(self, input):
        batch_size = input.size(0)
        return input.view(batch_size, -1)

In [None]:
def Mask(nb_batch, device=None):
    """Build the random channel-dropping mask used by ``supervisor``.

    The 512 channels are split into 7 consecutive groups (six of 73 channels,
    a final one of 74). In each group exactly 10 channels, chosen with
    ``random.shuffle`` (so the ``random`` seed controls them), are zeroed.
    The same mask is repeated for every sample of the batch.

    Args:
        nb_batch: batch size.
        device: target device; defaults to CUDA when available, else CPU.

    Returns:
        float32 tensor of shape (nb_batch, 512, 1, 1).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pattern = []
    for group in range(7):
        # 6 * 73 + 74 == 512 channels in total.
        group_mask = [1] * (64 if group == 6 else 63) + [0] * 10
        random.shuffle(group_mask)
        pattern += group_mask

    mask = np.array([pattern] * nb_batch, dtype="float32").reshape(nb_batch, 512, 1, 1)
    return torch.from_numpy(mask).to(device)

In [None]:
###### channel separation and channel diverse loss
def supervisor(x, targets, cnum):
    """Channel separation / diversity losses on a feature batch.

    Assumes ``x`` is (B, C) with C == 512 (the size of ``Mask``'s output).
    Channels are grouped into consecutive blocks of ``cnum`` and reduced with a
    cross-channel max (``my_MaxPool2d``).

    Returns:
        [loss_1, loss_2] where
        loss_1 -- cross-entropy of the per-group maxima of the randomly
                  masked features (``Mask``) against ``targets``;
        loss_2 -- 1 - (mean of the per-group maxima over batch and groups) / cnum.
    """
    batch, channels = x.size(0), x.size(1)
    group_max = my_MaxPool2d(kernel_size=(1, cnum), stride=(1, cnum))

    # Diversity term on the un-masked features.
    pooled = group_max(x.reshape(batch, channels, 1, 1))
    pooled = pooled.reshape(pooled.size(0), pooled.size(1), pooled.size(2) * pooled.size(3))
    loss_2 = 1.0 - 1.0 * torch.mean(torch.sum(pooled, 2)) / cnum

    # Separation term: drop random channels, pool per group, classify the maxima.
    masked = x.reshape(batch, channels, 1, 1) * Mask(batch)
    logits = group_max(masked)
    logits = logits.view(logits.size(0), -1)
    loss_1 = nn.CrossEntropyLoss()(logits, targets)

    return [loss_1, loss_2]

In [None]:
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.parameter import Parameter

class my_MaxPool2d(nn.Module):
    """Max pooling applied across the channel axis instead of the spatial axes.

    The input (B, C, H, W) is transposed to (B, W, H, C), so
    ``kernel_size=(kh, kw)`` pools over (H, C). With ``kernel_size=(1, k)``
    this takes the maximum over each consecutive group of ``k`` channels.
    The result is transposed back to (B, C', H', W) and made contiguous.
    Arguments mirror ``torch.nn.functional.max_pool2d``. With
    ``return_indices=True`` a ``(values, indices)`` pair is returned. The
    indices are positions in the transposed (H, C) plane.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                 return_indices=False, ceil_mode=False):
        super(my_MaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode

    def forward(self, input):
        pooled = F.max_pool2d(input.transpose(3, 1), self.kernel_size, self.stride,
                              self.padding, self.dilation,
                              ceil_mode=self.ceil_mode,
                              return_indices=self.return_indices)
        if self.return_indices:
            # F.max_pool2d returned a (values, indices) tuple: transpose both back.
            values, indices = pooled
            return (values.transpose(3, 1).contiguous(),
                    indices.transpose(3, 1).contiguous())
        return pooled.transpose(3, 1).contiguous()

    def __repr__(self):
        kh, kw = _pair(self.kernel_size)
        dh, dw = _pair(self.stride)
        padh, padw = _pair(self.padding)
        dilh, dilw = _pair(self.dilation)
        padding_str = f', padding=({padh}, {padw})' if padh != 0 or padw != 0 else ''
        dilation_str = f', dilation=({dilh}, {dilw})' if dilh != 0 and dilw != 0 else ''
        return (f'{self.__class__.__name__}(kernel_size=({kh}, {kw}), stride=({dh}, {dw})'
                f'{padding_str}{dilation_str}, ceil_mode={self.ceil_mode})')


class my_AvgPool2d(nn.Module):
    """Average pooling across the channel axis (counterpart of ``my_MaxPool2d``).

    Transposes (B, C, H, W) -> (B, W, H, C), applies ``F.avg_pool2d`` with the
    stored arguments over (H, C), then transposes back and makes it contiguous.
    """

    def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                 count_include_pad=True):
        super(my_AvgPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = kernel_size if not stride else stride
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad

    def forward(self, input):
        channels_last = input.transpose(3, 1)
        pooled = F.avg_pool2d(channels_last, self.kernel_size, self.stride,
                              self.padding, self.ceil_mode, self.count_include_pad)
        return pooled.transpose(3, 1).contiguous()

    def __repr__(self):
        fields = (
            ('kernel_size', self.kernel_size),
            ('stride', self.stride),
            ('padding', self.padding),
            ('ceil_mode', self.ceil_mode),
            ('count_include_pad', self.count_include_pad),
        )
        body = ', '.join(f'{key}={value}' for key, value in fields)
        return f'{self.__class__.__name__}({body})'

In [None]:
class Model(nn.Module):
    """ResNet-18 whose pooled features gate the CLIP image features.

    In ``forward`` the ResNet branch gives a vector ``r`` per image. The CLIP
    image embedding ``c`` of the same batch (the two must have equal width)
    is gated as ``c * sigmoid(r)`` and classified by ``self.fc``. In
    ``phase='train'`` the ``supervisor`` losses (cnum=73) on the gated
    features are returned as well.

    Args:
        pretrained: load ``model_path`` into the backbone (only if a path is given).
        num_classes: output size of ``self.fc``.
        drop_rate: stored as ``self.drop_rate``; not used by this class.
        model_path: checkpoint containing a ``'state_dict'`` entry for the ResNet-18.
    """

    def __init__(self, pretrained=True, num_classes=7, drop_rate=0, model_path=None):
        super(Model, self).__init__()

        res18 = ResNet(block=BasicBlock, n_blocks=[2, 2, 2, 2],
                       channels=[64, 128, 256, 512], output_dim=1000)
        if pretrained and model_path is not None:
            # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
            # strict=False: mismatched keys are ignored instead of raising.
            res18.load_state_dict(checkpoint['state_dict'], strict=False)

        self.drop_rate = drop_rate
        children = list(res18.children())
        # Everything up to layer4, then the adaptive avgpool; the 1000-way fc is dropped.
        self.features = nn.Sequential(*children[:-2])
        self.features2 = nn.Sequential(*children[-2:-1])

        fc_in_dim = children[-1].in_features  # 512 for this ResNet-18
        self.fc = nn.Linear(fc_in_dim, num_classes)

        # Plain (non-registered) dict of the classifier's parameters.
        self.parm = dict(self.fc.named_parameters())

    def forward(self, x, clip_model, targets, phase='train'):
        # CLIP is only a feature extractor here: no gradients flow into it.
        # NOTE(review): x carries the notebook's ImageNet normalisation rather than
        # CLIP's own `preprocess`; confirm this is intended.
        with torch.no_grad():
            image_features = clip_model.encode_image(x)

        gate = self.features2(self.features(x))
        gate = gate.view(gate.size(0), -1)
        # Sigmoid mask of the ResNet features applied to the CLIP features.
        gated = image_features * torch.sigmoid(gate)
        out = self.fc(gated)

        if phase == 'train':
            MC_loss = supervisor(gated, targets, cnum=73)
            return out, MC_loss
        return out, out

### Código de Treinamento e Teste

In [None]:
def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and
    request deterministic cuDNN kernels, for reproducible runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
from tqdm import tqdm

def train(model, train_loader, optimizer, scheduler, device, clip_net=None,
          lam_div=5.0, lam_sep=1.5):
    """Run one training epoch, then step the LR scheduler once.

    loss = CE(logits, labels) + lam_div * MC_loss[1] + lam_sep * MC_loss[0],
    where ``MC_loss`` are the two ``supervisor`` terms returned by the model.

    Args:
        model: called as ``model(imgs, clip, labels, phase='train')``.
        train_loader: yields ``(images, labels)`` batches.
        optimizer, scheduler: stepped per batch / once per epoch respectively.
        device: device the batches (and the model) are moved to.
        clip_net: CLIP model handed to ``model``; defaults to the notebook-level
            ``clip_model``.
        lam_div, lam_sep: weights of MC_loss[1] and MC_loss[0] (defaults are the
            previously hard-coded 5 and 1.5).

    Returns:
        (accuracy, mean_batch_loss) as 0-d tensors.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if clip_net is None:
        clip_net = clip_model
    criterion = nn.CrossEntropyLoss()

    model.to(device)
    model.train()

    running_loss = torch.zeros((), device=device)
    correct_sum = torch.zeros((), dtype=torch.long, device=device)
    n_batches = 0
    n_seen = 0

    with tqdm(total=len(train_loader)) as pbar:
        for imgs, labels in train_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            output, MC_loss = model(imgs, clip_net, labels, phase='train')
            loss = criterion(output, labels) + lam_div * MC_loss[1] + lam_sep * MC_loss[0]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # detach(): summing the live loss would keep every batch's autograd
            # graph reachable for the whole epoch.
            running_loss += loss.detach()
            correct_sum += torch.eq(output.argmax(dim=1), labels).sum()
            n_batches += 1
            n_seen += labels.size(0)

            pbar.update(1)

    scheduler.step()
    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return correct_sum.float() / float(n_seen), running_loss / n_batches


setup_seed(3407)

In [None]:
def test(model, test_loader, device):
    with torch.no_grad():
        model.eval()

        running_loss = 0.0
        iter_cnt = 0
        correct_sum = 0
        data_num = 0


        for batch_i, (imgs1, labels) in enumerate(test_loader):
            imgs1 = imgs1.to(device)
            labels = labels.to(device)


            outputs, _ = model(imgs1, clip_model, labels, phase='test')


            loss = nn.CrossEntropyLoss()(outputs, labels)

            iter_cnt += 1
            _, predicts = torch.max(outputs, 1)

            correct_num = torch.eq(predicts, labels).sum()
            correct_sum += correct_num

            running_loss += loss
            data_num += outputs.size(0)

        running_loss = running_loss / iter_cnt
        test_acc = correct_sum.float() / float(data_num)
        
    return test_acc, running_loss

In [None]:
# Backbone checkpoint: ResNet-18 weights pretrained on MS-Celeb.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_path = '/kaggle/input/cafe-repo/Generalizable-FER-main/models/resnet18_msceleb.pth'
model = Model(model_path=model_path)
model.to(device)

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR

# Adam (lr 2e-4, weight decay 1e-4); each scheduler.step() multiplies the LR by 0.9.
optimizer = optim.Adam(params=model.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler = ExponentialLR(optimizer=optimizer, gamma=0.9)

In [None]:
# Deterministic evaluation pipeline: resize to 224x224, to tensor, then
# normalise with the standard ImageNet mean/std.
eval_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Training pipeline: the evaluation preprocessing followed by augmentations that
# operate on the normalised tensor (horizontal flip, random erasing).
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomErasing(scale=(0.02, 0.25)),
    ]
)

### Carregamento do Dataset RAF-DB

In [None]:
import torch.utils.data as data
import cv2
import pandas as pd
import os
# import image_utils
import random
import cv2
import numpy as np



class RafDataSet(data.Dataset):
    """RAF-DB basic-expression dataset using the aligned face crops.

    Reads ``EmoLabel/list_patition_label.txt`` under ``raf_path`` (space
    separated ``<name> <label 1..7>``), keeps the ``train*`` or ``test*`` rows
    and maps ``xxx.jpg`` to ``Image/aligned/xxx_aligned.jpg``. Labels are
    shifted to 0..6: 0 Surprise, 1 Fear, 2 Disgust, 3 Happiness, 4 Sadness,
    5 Anger, 6 Neutral.

    Args:
        raf_path: dataset root.
        idxs_raf: optional dict index -> emotion name for RAF-DB.
        idxs_test: optional dict emotion name -> index of the target label space.
            When both are given the returned label is ``idxs_test[idxs_raf[y]]``;
            otherwise the RAF index is returned unchanged.
        dataidxs: optional index array selecting a subset of the rows.
        train: use the training (True) or test (False) rows.
        transform: applied to the PIL RGB image.
        basic_aug, download: stored / accepted but not used.
    """

    def __init__(self, raf_path, idxs_raf=None, idxs_test=None, dataidxs=None, train=True,
                 transform=None, basic_aug=False, download=False):
        self.train = train
        self.dataidxs = dataidxs
        self.transform = transform
        self.raf_path = raf_path
        self.idxs_raf = idxs_raf
        self.idxs_test = idxs_test
        self.basic_aug = basic_aug

        NAME_COLUMN = 0
        LABEL_COLUMN = 1
        df = pd.read_csv(os.path.join(self.raf_path, 'EmoLabel/list_patition_label.txt'), sep=' ', header=None)
        prefix = 'train' if self.train else 'test'
        dataset = df[df[NAME_COLUMN].str.startswith(prefix)]
        file_names = dataset.iloc[:, NAME_COLUMN].values
        self.target = np.array(dataset.iloc[:, LABEL_COLUMN].astype(int).values - 1)

        # Aligned crop path for every row.
        file_paths = np.array([
            os.path.join(self.raf_path, 'Image/aligned', f.split(".")[0] + "_aligned.jpg")
            for f in file_names
        ])
        if self.dataidxs is not None:
            file_paths = file_paths[self.dataidxs]
            self.target = self.target[self.dataidxs]
        self.file_paths = file_paths.tolist()

    def __len__(self):
        return len(self.file_paths)

    def get_labels(self):
        return self.target

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        sample = cv2.imread(path)
        if sample is None:
            # cv2.imread signals a missing or unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        sample = sample[:, :, ::-1]  # OpenCV loads BGR; convert to RGB

        target = self.target[idx]
        if self.idxs_raf is not None and self.idxs_test is not None:
            target = self.idxs_test[self.idxs_raf[target]]
        else:
            target = int(target)

        if self.transform is not None:
            # Local import so this class does not rely on a module-level `Image`.
            from PIL import Image
            sample = Image.fromarray(sample.copy())  # copy: drop the negative stride
            sample = self.transform(sample)

        return sample, target


In [None]:
# RAF-DB label order (label-file value minus 1).
emotion_to_index = {
    name: idx
    for idx, name in enumerate(
        ("surprise", "fear", "disgust", "happiness", "sadness", "angry", "neutral")
    )
}

index_to_emotion = dict(zip(emotion_to_index.values(), emotion_to_index.keys()))

In [None]:
test_dataset = RafDataSet(
    '/kaggle/input/eacdata/raf-basic',
    idxs_raf=index_to_emotion,
    idxs_test=emotion_to_index,
    train=False,
    transform=eval_transforms,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
# Same label mapping as the RAF-DB test split, so both loaders share one label space.
train_dataset = RafDataSet(
    '/kaggle/input/eacdata/raf-basic',
    idxs_raf=index_to_emotion,
    idxs_test=emotion_to_index,
    train=True,
    transform=train_transforms,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
import matplotlib.pyplot as plt

# Preview one batch of the RAF-DB test loader on a 4x8 grid.
images, labels = next(iter(test_loader))

fig, axes = plt.subplots(4, 8, figsize=(20, 10))
axes = axes.flatten()

# Mean/std of the Normalize() step in the transforms, used to undo it for display.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

for ax in axes:
    ax.axis('off')  # also hides empty cells when the batch has fewer than 32 images

# zip() stops at the shorter of axes / batch.
for ax, img, label in zip(axes, images, labels):
    img_np = (img * _std + _mean).clamp(0, 1).permute(1, 2, 0).numpy()
    ax.imshow(img_np)
    ax.set_title(f"{index_to_emotion[label.item()]}", fontsize=10)

plt.tight_layout()
plt.show()



In [None]:
best_acc = 0.0
patience = 5  # epochs without test-accuracy improvement before stopping
no_improvement = 0

# NOTE(review): the "best" checkpoint and early stopping are driven by the test
# split, so best_acc is an optimistic estimate; a validation split would be cleaner.
for i in range(1, args.epochs + 1):
    train_acc, train_loss = train(model, train_loader, optimizer, scheduler, device)
    test_acc, test_loss = test(model, test_loader, device)
    # Plain floats so logs/results.txt do not contain tensor reprs.
    train_acc, test_acc = float(train_acc), float(test_acc)
    print('epoch: ', i, 'acc_test: ', test_acc, 'acc_train: ', train_acc)

    if test_acc > best_acc:
        best_acc = test_acc
        no_improvement = 0
        torch.save({'model_state_dict': model.state_dict()}, "/kaggle/working/ours_best_RAFDB.pth")
    else:
        no_improvement += 1

    # Latest weights and per-epoch result, also written on the stopping epoch.
    torch.save({'model_state_dict': model.state_dict()}, "/kaggle/working/ours_final_RAFDB.pth")
    with open('results.txt', 'a') as f:
        f.write(f"{i}_{test_acc}\n")

    if no_improvement >= patience:
        print(f"Early stopping at epoch {i}: no test-accuracy improvement for {patience} epochs")
        break

### Carregamento do Dataset FER+

In [None]:
import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image


class FERPlusDataset(Dataset):
    """FER+ dataset: FER2013 images with crowd-sourced emotion votes.

    Reads ``<root_dir>/Labels/<subset>/label.csv`` (no header: image name, a
    second column stored as ``format``, then vote counts for 10 categories)
    and loads images from ``<root_dir>/Images/<subset>``. Each image gets one
    label: the most-voted of the 7 basic emotions (contempt, unknown and NF
    are ignored; ties go to the first column).

    Args:
        root_dir (str): dataset root (e.g. 'FER2013Plus').
        idxs_fer (dict): index -> emotion name in this class's label order.
        idxs_test (dict): emotion name -> index in the target label space.
        subset (str): 'FER2013Train', 'FER2013Test' or 'FER2013Valid'.
        transform (callable, optional): applied to the PIL RGB image.
        drop_unlabeled (bool): drop rows with zero votes for all 7 basic
            emotions (otherwise they default to 'neutral'). Off by default.

    Raises:
        FileNotFoundError: if the label file does not exist.
    """

    # Vote columns of the 7 basic emotions; their order defines the class indices.
    EMOTION_COLUMNS = ["neutral", "happiness", "surprise", "sadness",
                       "anger", "disgust", "fear"]

    def __init__(self, root_dir, idxs_fer, idxs_test, subset="FER2013Train", transform=None,
                 drop_unlabeled=False):
        self.root_dir = root_dir
        self.subset = subset
        self.transform = transform
        self.idxs_fer = idxs_fer
        self.idxs_test = idxs_test

        # Image directory and label file of this subset.
        self.images_dir = os.path.join(root_dir, "Images", subset)
        self.labels_path = os.path.join(root_dir, "Labels", subset, "label.csv")

        if not os.path.exists(self.labels_path):
            raise FileNotFoundError(f"Arquivo de labels não encontrado: {self.labels_path}")

        self.columns = [
            "image_name", "format", "neutral", "happiness", "surprise", "sadness",
            "anger", "disgust", "fear", "contempt", "unknown", "NF"
        ]
        labels = pd.read_csv(self.labels_path, header=None, names=self.columns)

        # (N, 7) numeric vote matrix for the basic emotions.
        votes = labels[self.EMOTION_COLUMNS].to_numpy(dtype=float)
        if drop_unlabeled:
            has_votes = votes.sum(axis=1) > 0
            labels = labels[has_votes].reset_index(drop=True)
            votes = votes[has_votes]

        self.labels = labels
        self.image_files = self.labels['image_name']

        # Emotion name -> class index (same order as EMOTION_COLUMNS).
        self.emotion_to_index = {name: i for i, name in enumerate(self.EMOTION_COLUMNS)}

        # Majority vote for every row, computed once. argmax takes the first
        # maximum on ties (like idxmax); an all-zero row maps to 0 ('neutral').
        self.targets = votes.argmax(axis=1)

    def get_single_label_filtered(self, row):
        """Return the index of the most-voted basic emotion in ``row``.

        ``row`` is one row of ``self.labels``. Contempt, unknown and NF are
        ignored. The values are converted to numbers first because a row taken
        from this mixed-dtype frame is an object-dtype Series.
        """
        votes = pd.to_numeric(row[self.EMOTION_COLUMNS])
        return self.emotion_to_index[votes.idxmax()]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            raise NotImplementedError("Slices não são suportados nesta implementação.")

        img_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        single_label = int(self.targets[idx])
        single_label = self.idxs_test[self.idxs_fer[single_label]]

        return image, single_label


In [None]:
root_dir = "/".join(["/kaggle/input", "ferplus", "FER2013Plus"])  # FER+ dataset root

In [None]:
# FER+ label order (same order as FERPlusDataset's vote columns).
emotion_to_index = dict(
    (name, idx)
    for idx, name in enumerate(
        ["neutral", "happiness", "surprise", "sadness", "anger", "disgust", "fear"]
    )
)

index_to_emotion = {idx: name for name, idx in emotion_to_index.items()}

In [None]:
train_dataset = FERPlusDataset(
    root_dir=root_dir,
    idxs_fer=index_to_emotion,
    idxs_test=emotion_to_index,
    subset="FER2013Train",
    transform=train_transforms,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
test_dataset = FERPlusDataset(
    root_dir=root_dir,
    idxs_fer=index_to_emotion,
    idxs_test=emotion_to_index,
    subset="FER2013Test",
    transform=eval_transforms,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
import matplotlib.pyplot as plt

# Preview one batch of the FER+ training loader on a 4x8 grid.
images, labels = next(iter(train_loader))

fig, axes = plt.subplots(4, 8, figsize=(20, 10))
axes = axes.flatten()

# Mean/std of the Normalize() step in the transforms, used to undo it for display.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

for ax in axes:
    ax.axis('off')  # also hides empty cells when the batch has fewer than 32 images

# zip() stops at the shorter of axes / batch.
for ax, img, label in zip(axes, images, labels):
    img_np = (img * _std + _mean).clamp(0, 1).permute(1, 2, 0).numpy()
    ax.imshow(img_np)
    ax.set_title(f"{index_to_emotion[label.item()]}", fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
best_acc = 0.0
# Early stopping is disabled for this run (patience=None); set an int to enable it.
patience = None
no_improvement = 0

# NOTE(review): model, optimizer and scheduler carry over from the previous run
# (trained weights, an fc head with a different label order, an already-decayed
# learning rate). Re-create them first if an independent FER+ run is intended.
for i in range(1, args.epochs + 1):
    train_acc, train_loss = train(model, train_loader, optimizer, scheduler, device)
    test_acc, test_loss = test(model, test_loader, device)
    train_acc, test_acc = float(train_acc), float(test_acc)
    print('epoch: ', i, 'acc_test: ', test_acc, 'acc_train: ', train_acc)

    if test_acc > best_acc:
        best_acc = test_acc
        no_improvement = 0
        torch.save({'model_state_dict': model.state_dict()}, "/kaggle/working/ours_best_FERPlus.pth")
    else:
        no_improvement += 1

    # Latest weights and per-epoch result, written every epoch (improving or not).
    torch.save({'model_state_dict': model.state_dict()}, "/kaggle/working/ours_final_FERPlus.pth")
    with open('results.txt', 'a') as f:
        f.write(f"{i}_{test_acc}\n")

    if patience is not None and no_improvement >= patience:
        print(f"Early stopping at epoch {i}: no test-accuracy improvement for {patience} epochs")
        break

### Train and test on AffectNet

In [None]:
# AffectNet label order (matches the numbered class folders read by AffectNetDataset).
_names = ["angry", "disgust", "fear", "happiness", "sadness", "surprise", "neutral"]
emotion_to_index = {}
for _idx, _name in enumerate(_names):
    emotion_to_index[_name] = _idx
del _names, _idx, _name

index_to_emotion = {idx: name for name, idx in emotion_to_index.items()}

In [None]:
from torch.utils.data import Dataset

class AffectNetDataset(Dataset):
    """AffectNet images stored as ``<root_dir>/<split>/<class index>/<image>``.

    Class folders are named ``0``..``6`` (order of ``class_map``). Missing
    folders are skipped, and only .png/.jpg/.jpeg files are used. Files are
    listed in sorted order so the sample index -> file mapping is reproducible
    (``os.listdir`` order is arbitrary).

    Args:
        root_dir (str): directory containing the train, test and valid folders.
        idx_aff (dict): index -> emotion name for AffectNet.
        idx_test (dict): emotion name -> index in the target label space.
        split (str): which partition to load ("train", "test" or "valid").
        transform (callable, optional): applied to the PIL RGB image.
    """

    def __init__(self, root_dir, idx_aff, idx_test, split="train", transform=None):
        self.root_dir = os.path.join(root_dir, split)
        self.transform = transform
        self.idxs_test = idx_test
        self.idxs_aff = idx_aff

        self.class_map = {
            "angry": 0,
            "disgust": 1,
            "fear": 2,
            "happiness": 3,
            "sadness": 4,
            "surprise": 5,
            "neutral": 6
        }

        self.samples = []
        for class_idx in range(len(self.class_map)):  # folders are named by index
            class_path = os.path.join(self.root_dir, str(class_idx))
            if not os.path.exists(class_path):
                continue
            for filename in sorted(os.listdir(class_path)):
                if filename.endswith(('.png', '.jpg', '.jpeg')):
                    self.samples.append((os.path.join(class_path, filename), class_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")

        label = self.idxs_test[self.idxs_aff[label]]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
root_dir = "/".join(["/kaggle/input", "affectnetaligned", "AffectNetCustom"])  # AffectNet root

In [None]:
train_dataset = AffectNetDataset(
    root_dir=root_dir,
    idx_aff=index_to_emotion,
    idx_test=emotion_to_index,
    split="train",
    transform=train_transforms,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
test_dataset = AffectNetDataset(
    root_dir=root_dir,
    idx_aff=index_to_emotion,
    idx_test=emotion_to_index,
    split="test",
    transform=eval_transforms,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
import matplotlib.pyplot as plt

# Preview one batch of the AffectNet training loader on a 4x8 grid.
images, labels = next(iter(train_loader))

fig, axes = plt.subplots(4, 8, figsize=(20, 10))
axes = axes.flatten()

# Mean/std of the Normalize() step in the transforms, used to undo it for display.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

for ax in axes:
    ax.axis('off')  # also hides empty cells when the batch has fewer than 32 images

# zip() stops at the shorter of axes / batch.
for ax, img, label in zip(axes, images, labels):
    img_np = (img * _std + _mean).clamp(0, 1).permute(1, 2, 0).numpy()
    ax.imshow(img_np)
    ax.set_title(f"{index_to_emotion[label.item()]}", fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
best_acc = 0.0
# Early stopping is disabled for this run (patience=None); set an int to enable it.
patience = None
no_improvement = 0

# NOTE(review): model, optimizer and scheduler carry over from the previous runs
# (trained weights, an fc head with a different label order, an already-decayed
# learning rate). Re-create them first if an independent AffectNet run is intended.
for i in range(1, args.epochs + 1):
    train_acc, train_loss = train(model, train_loader, optimizer, scheduler, device)
    test_acc, test_loss = test(model, test_loader, device)
    train_acc, test_acc = float(train_acc), float(test_acc)
    print('epoch: ', i, 'acc_test: ', test_acc, 'acc_train: ', train_acc)

    if test_acc > best_acc:
        best_acc = test_acc
        no_improvement = 0
        torch.save({'model_state_dict': model.state_dict()}, "/kaggle/working/ours_best_AffectNet.pth")
    else:
        no_improvement += 1

    # Latest weights and per-epoch result, written every epoch (improving or not).
    torch.save({'model_state_dict': model.state_dict()}, "/kaggle/working/ours_final_AffectNet.pth")
    with open('results.txt', 'a') as f:
        f.write(f"{i}_{test_acc}\n")

    if patience is not None and no_improvement >= patience:
        print(f"Early stopping at epoch {i}: no test-accuracy improvement for {patience} epochs")
        break

### Train and test on MMA (MMAFEDB)

In [None]:
# NOTE(review): the indices appear to follow MMADataset's sorted class-folder order
# (neutral=4 before sadness=5 and surprise=6) -- confirm against the folder names.
emotion_to_index = dict(zip(
    ("angry", "disgust", "fear", "happiness", "sadness", "surprise", "neutral"),
    (0, 1, 2, 3, 5, 6, 4),
))

index_to_emotion = {idx: name for name, idx in emotion_to_index.items()}

In [None]:
class MMADataset(Dataset):
    """MMA facial-expression images stored as ``<root_dir>/<split>/<class name>/<image>``.

    Class indices follow the alphabetical order of the class sub-directories;
    plain files in the split directory are ignored. Only .png/.jpg/.jpeg files
    are used, listed in sorted order so the sample index -> file mapping is
    reproducible (``os.listdir`` order is arbitrary).

    Args:
        root_dir (str): directory containing the train, test and valid folders.
        idx_mma (dict): index -> emotion name for this dataset.
        idx_test (dict): emotion name -> index in the target label space.
        split (str): which partition to load ("train", "test" or "valid").
        transform (callable, optional): applied to the PIL RGB image.
    """

    def __init__(self, root_dir, idx_mma, idx_test, split="train", transform=None):
        self.root_dir = os.path.join(root_dir, split)
        self.transform = transform
        # Only sub-directories are classes: a stray file (e.g. .DS_Store) would
        # otherwise shift every index and make os.listdir below fail.
        self.classes = sorted(
            name for name in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, name))
        )
        self.idxs_test = idx_test
        self.idxs_mma = idx_mma

        self.samples = []
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(self.root_dir, class_name)
            for filename in sorted(os.listdir(class_path)):
                if filename.endswith(('.png', '.jpg', '.jpeg')):
                    self.samples.append((os.path.join(class_path, filename), class_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = self.idxs_test[self.idxs_mma[label]]

        return image, label

In [None]:
root_dir = "/".join(["/kaggle/input", "mma-facial-expression", "MMAFEDB"])  # MMA dataset root

In [None]:
train_dataset = MMADataset(
    root_dir=root_dir,
    idx_mma=index_to_emotion,
    idx_test=emotion_to_index,
    split="train",
    transform=train_transforms,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
test_dataset = MMADataset(
    root_dir=root_dir,
    idx_mma=index_to_emotion,
    idx_test=emotion_to_index,
    split="test",
    transform=eval_transforms,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check: take a single batch from the (shuffled) training loader.
images, labels = next(iter(train_loader))

# 4x8 grid = up to 32 thumbnails; extra images in larger batches are not shown.
fig, axes = plt.subplots(4, 8, figsize=(20, 10))
axes = axes.flatten()

# Hide every cell up front so cells left unused by a smaller batch stay blank.
for ax in axes:
    ax.axis('off')

for ax, img, label in zip(axes, images, labels):
    # Assumes CHW image tensors: permute to HWC and min-max rescale to [0, 1] for imshow.
    # A constant image would divide by zero, so it is left at 0 instead.
    img_np = img.permute(1, 2, 0).numpy()
    img_np = img_np - img_np.min()
    value_range = img_np.max()
    if value_range > 0:
        img_np = img_np / value_range

    ax.imshow(img_np)
    # Title each thumbnail with the emotion name of its label index.
    ax.set_title(f"{index_to_emotion[label.item()]}", fontsize=10)

# Keep thumbnails and titles from overlapping.
plt.tight_layout()
plt.show()

In [None]:
best_acc = 0
patience = 5  # consecutive epochs without a test-accuracy improvement before stopping
no_improvement = 0

for i in range(1, args.epochs + 1):
    train_acc, train_loss = train(model, train_loader, optimizer, scheduler, device)
    test_acc, test_loss = test(model, test_loader, device)
    print('epoch: ', i, 'acc_test: ', test_acc, 'acc_train: ', train_acc)

    # Persist this epoch's outcome before any early stop, so the last epoch is
    # always logged and the "final" checkpoint really holds the final weights.
    torch.save({'model_state_dict': model.state_dict(),}, "/kaggle/working/ours_final_MMA.pth")
    with open('results.txt', 'a') as f:
        f.write(str(i)+'_'+str(test_acc)+'\n')

    # Keep the best-so-far checkpoint and track how long accuracy has stalled.
    if test_acc > best_acc:
        best_acc = test_acc
        no_improvement = 0
        torch.save({'model_state_dict': model.state_dict(),}, "/kaggle/working/ours_best_MMA.pth")
    else:
        no_improvement += 1

    if no_improvement >= patience:
        print(f"Early stopping at epoch {i}: no test accuracy improvement for {patience} epochs")
        break

### SFEW 2.0

In [None]:
# SFEW label space: emotion name -> class index. The key order (neutral listed last)
# is kept so iteration order of both dicts is unchanged.
# NOTE(review): neutral=4 presumably mirrors the alphabetical order of the class folders — verify.
emotion_to_index = dict(
    angry=0,
    disgust=1,
    fear=2,
    happiness=3,
    sadness=5,
    surprise=6,
    neutral=4,
)

# Inverse lookup: class index -> emotion name.
index_to_emotion = dict(zip(emotion_to_index.values(), emotion_to_index.keys()))

In [None]:
from torch.utils.data import Dataset

class SFEWDataset(Dataset):
    """Image-folder dataset for SFEW 2.0 laid out as ``<root>/<split>/<class>/<image>``.

    For ``split == "Test"`` images are read from ``<root>/Test/Test_Aligned_Faces/<class>``.
    Class folders are sorted alphabetically and numbered 0..K-1. In ``__getitem__`` that
    folder index is remapped to the target label space via ``idx_test[idx_sfew[index]]``.
    """

    # Accepted image file suffixes (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, idx_test, idx_sfew, split="Train", transform=None):
        """
        Args:
            root_dir (str): Root directory containing the split folders
                (this notebook uses "Train", "Val" and "Test").
            idx_test (dict): Emotion name -> label index of the target label space.
            idx_sfew (dict): SFEW folder index -> emotion name.
            split (str): Split folder to load; "Test" additionally descends into
                "Test_Aligned_Faces".
            transform (callable, optional): Transform applied to each PIL image.
        """
        if split == "Test":
            self.root_dir = os.path.join(root_dir, split, "Test_Aligned_Faces")
        else:
            self.root_dir = os.path.join(root_dir, split)
        print(self.root_dir)
        self.transform = transform
        # Only sub-directories are classes; a stray file would otherwise be
        # treated as a class and break the listdir below.
        self.classes = sorted(
            entry for entry in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, entry))
        )
        self.idxs_test = idx_test
        self.idxs_sfew = idx_sfew

        self.samples = []
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(self.root_dir, class_name)
            # Sorted so sample order is reproducible across filesystems.
            for filename in sorted(os.listdir(class_path)):
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(class_path, filename), class_idx))

    def __len__(self):
        """Number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` with the label mapped into the target label space."""
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")

        # Folder index -> emotion name (idxs_sfew) -> target index (idxs_test).
        label = self.idxs_test[self.idxs_sfew[label]]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# SFEW 2.0 root; SFEWDataset below reads its "Train" and "Val" split folders.
root_dir = "/".join(["/kaggle/input", "datasetsfew"])

In [None]:
# SFEW training split with training-time transforms, served in shuffled batches.
train_dataset = SFEWDataset(
    root_dir=root_dir,
    idx_test=emotion_to_index,
    idx_sfew=index_to_emotion,
    split="Train",
    transform=train_transforms,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
# SFEW "Val" split is used as the test set. It is evaluated with eval_transforms, as the
# MMA test loader is; train_transforms presumably adds random augmentation, which would
# make the reported test accuracy noisy and non-reproducible.
test_dataset = SFEWDataset(root_dir=root_dir, idx_test=emotion_to_index, idx_sfew=index_to_emotion, split="Val", transform=eval_transforms)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check: take a single batch from the (shuffled) training loader.
images, labels = next(iter(train_loader))

# 4x8 grid = up to 32 thumbnails; extra images in larger batches are not shown.
fig, axes = plt.subplots(4, 8, figsize=(20, 10))
axes = axes.flatten()

# Hide every cell up front so cells left unused by a smaller batch stay blank.
for ax in axes:
    ax.axis('off')

for ax, img, label in zip(axes, images, labels):
    # Assumes CHW image tensors: permute to HWC and min-max rescale to [0, 1] for imshow.
    # A constant image would divide by zero, so it is left at 0 instead.
    img_np = img.permute(1, 2, 0).numpy()
    img_np = img_np - img_np.min()
    value_range = img_np.max()
    if value_range > 0:
        img_np = img_np / value_range

    ax.imshow(img_np)
    # Title each thumbnail with the emotion name of its label index.
    ax.set_title(f"{index_to_emotion[label.item()]}", fontsize=10)

# Keep thumbnails and titles from overlapping.
plt.tight_layout()
plt.show()

In [None]:
best_acc = 0
# Consecutive epochs without a test-accuracy improvement before stopping; at 100,
# early stopping only ever triggers on runs longer than 100 epochs.
patience = 100
no_improvement = 0

for i in range(1, args.epochs + 1):
    train_acc, train_loss = train(model, train_loader, optimizer, scheduler, device)
    test_acc, test_loss = test(model, test_loader, device)
    print('epoch: ', i, 'acc_test: ', test_acc, 'acc_train: ', train_acc)

    # Persist this epoch's outcome before any early stop, so the last epoch is
    # always logged and the "final" checkpoint really holds the final weights.
    torch.save({'model_state_dict': model.state_dict(),}, "/kaggle/working/ours_final_SFEW.pth")
    with open('results.txt', 'a') as f:
        f.write(str(i)+'_'+str(test_acc)+'\n')

    # Keep the best-so-far checkpoint and track how long accuracy has stalled.
    if test_acc > best_acc:
        best_acc = test_acc
        no_improvement = 0
        torch.save({'model_state_dict': model.state_dict(),}, "/kaggle/working/ours_best_SFEW.pth")
    else:
        no_improvement += 1

    if no_improvement >= patience:
        print(f"Early stopping at epoch {i}: no test accuracy improvement for {patience} epochs")
        break

### Teste entre domínios

In [None]:
# MMA label space: emotion name -> class index.
emotion_to_index_mma = {
            "angry": 0,
            "disgust": 1,
            "fear": 2,
            "happiness": 3,
            "sadness": 5,
            "surprise": 6,
            "neutral": 4
        }

# Inverse (class index -> emotion name), built from this cell's own table so the
# cell does not depend on globals defined by other sections.
index_to_emotion_mma = {v: k for k, v in emotion_to_index_mma.items()}

In [None]:
# SFEW label space: emotion name -> class index.
emotion_to_index_sfew = {
            "angry": 0,
            "disgust": 1,
            "fear": 2,
            "happiness": 3,
            "sadness": 5,
            "surprise": 6,
            "neutral": 4
        }

# Inverse (class index -> emotion name), built from this cell's own table so the
# cell does not depend on globals defined by other sections.
index_to_emotion_sfew = {v: k for k, v in emotion_to_index_sfew.items()}

In [None]:
# AffectNet (custom aligned export) label space: emotion name -> class index.
emotion_to_index_aff = {
            "angry": 0,
            "disgust": 1,
            "fear": 2,
            "happiness": 3,
            "sadness": 4,
            "surprise": 5,
            "neutral": 6
        }

# Inverse (class index -> emotion name) of THIS table. Inverting another dataset's
# table would map indices 4/5/6 to the wrong emotions, because AffectNet orders
# sadness/surprise/neutral differently.
index_to_emotion_aff = {v: k for k, v in emotion_to_index_aff.items()}

In [None]:
# RAF-DB label space: emotion name -> class index.
# NOTE(review): appears to follow RAF-DB's basic-emotion numbering (1-based labels
# shifted to 0) — confirm against RafDataSet.
emotion_to_index_raf = {
            "surprise": 0,
            "fear": 1,
            "disgust": 2,
            "happiness": 3,
            "sadness": 4,
            "angry": 5,
            "neutral": 6
        }

# Inverse (class index -> emotion name) of THIS table. Inverting another dataset's
# table would, e.g., relabel RAF index 0 (surprise) as "angry".
index_to_emotion_raf = {v: k for k, v in emotion_to_index_raf.items()}

In [None]:
# FER+ label space: emotion name -> class index.
emotion_to_index_fer = dict(
    neutral=0,
    happiness=1,
    surprise=2,
    sadness=3,
    angry=4,
    disgust=5,
    fear=6,
)

# Inverse lookup: class index -> emotion name.
index_to_emotion_fer = dict(zip(emotion_to_index_fer.values(), emotion_to_index_fer.keys()))

### SFEW

In [None]:
root_dir = "/kaggle/input/datasetsfew"
# SFEW "Val" split relabelled into the SFEW label space. It is evaluated with
# eval_transforms, like the other test loaders; train_transforms presumably adds
# random augmentation, which would make accuracy noisy.
dataset_sfew = SFEWDataset(root_dir=root_dir, idx_test=emotion_to_index_sfew, idx_sfew=index_to_emotion_sfew, split="Val", transform=eval_transforms)
loader_sfew = torch.utils.data.DataLoader(dataset_sfew, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

In [None]:
root_dir = "/".join(["/kaggle/input", "mma-facial-expression", "MMAFEDB"])
# MMA test split, relabelled into the SFEW label space.
dataset_mma = MMADataset(
    root_dir=root_dir,
    idx_test=emotion_to_index_sfew,
    idx_mma=index_to_emotion_mma,
    split="test",
    transform=eval_transforms,
)
loader_mma = torch.utils.data.DataLoader(
    dataset_mma,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
root_dir = "/".join(["/kaggle/input", "affectnetaligned", "AffectNetCustom"])
# AffectNet test split, relabelled into the SFEW label space.
dataset_affect = AffectNetDataset(
    root_dir=root_dir,
    idx_test=emotion_to_index_sfew,
    idx_aff=index_to_emotion_aff,
    split="test",
    transform=eval_transforms,
)
loader_affect = torch.utils.data.DataLoader(
    dataset_affect,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
# RAF-DB test split (train=False) relabelled into the SFEW label space. It is evaluated
# with eval_transforms, like the other test loaders; train_transforms presumably adds
# random augmentation, which would make accuracy noisy.
dataset_raf = RafDataSet('/kaggle/input/eacdata/raf-basic', idxs_test=emotion_to_index_sfew, idxs_raf=index_to_emotion_raf, train=False, transform=eval_transforms)
loader_raf = torch.utils.data.DataLoader(dataset_raf,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.workers,
                                           pin_memory=True)

In [None]:
root_dir = "/".join(["/kaggle/input", "ferplus", "FER2013Plus"])
# FER+ test subset, relabelled into the SFEW label space.
dataset_fer = FERPlusDataset(
    root_dir=root_dir,
    idxs_test=emotion_to_index_sfew,
    idxs_fer=index_to_emotion_fer,
    subset="FER2013Test",
    transform=eval_transforms,
)
loader_fer = torch.utils.data.DataLoader(
    dataset_fer,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

In [None]:
def load_pretrained_weights(model, checkpoint):
    """Copy every shape-compatible tensor from ``checkpoint`` into ``model``.

    Args:
        model: ``torch.nn.Module`` updated in place.
        checkpoint: a state dict, or a mapping holding one under ``'state_dict'``.
            A leading ``'module.'`` prefix (added by ``nn.DataParallel``) is
            stripped from keys before matching.

    Returns:
        The same ``model`` object. Entries whose key is unknown to the model, or
        whose tensor size differs, are skipped. Their count is printed so that a
        mismatched checkpoint (a partially loaded model) does not go unnoticed.
    """
    import collections

    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    model_dict = model.state_dict()
    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[len('module.'):]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)
    # Start from the model's own tensors so skipped keys keep their current values.
    model_dict.update(new_state_dict)

    model.load_state_dict(model_dict)
    print('load_weight', len(matched_layers))
    if discarded_layers:
        print('discarded', len(discarded_layers), 'keys not matching the model:', discarded_layers[:10])
    return model

In [None]:
# Load the best SFEW checkpoint. map_location=device lets a checkpoint saved on a GPU
# be loaded in a session whose `device` is the CPU (and vice versa).
checkpoint = torch.load('/kaggle/working/ours_best_SFEW.pth', map_location=device)
checkpoint = checkpoint["model_state_dict"]
model = load_pretrained_weights(model, checkpoint)
model.to(device)

In [None]:
###### SFEW
acc_sfew_sfew, test_loss = test(model, loader_sfew, device)
print('test acc dbtrain-sfew dbtest-sfew: ', acc_sfew_sfew)

###### FER
acc_sfew_fer, test_loss = test(model, loader_fer, device)
print('test acc dbtrain-sfew dbtest-fer: ', acc_sfew_fer)

###### AFFECT
acc_sfew_affect, test_loss = test(model, loader_affect, device)
print('test acc dbtrain-sfew dbtest-affect: ', acc_sfew_affect)

###### MMA
acc_sfew_mma, test_loss = test(model, loader_mma, device)
print('test acc dbtrain-sfew dbtest-mma: ', acc_sfew_mma)

### RAFDB
acc_sfew_raf, test_loss = test(model, loader_raf, device)
print('test acc dbtrain-sfew dbtest-raf: ', acc_sfew_raf)

### MMA

In [None]:
# Rebuild every test loader so labels are expressed in the MMA label space
# (idx_test/idxs_test=emotion_to_index_mma). All loaders use eval_transforms;
# train_transforms presumably adds random augmentation, which would make accuracy noisy.
root_dir = "/kaggle/input/datasetsfew"
dataset_sfew = SFEWDataset(root_dir=root_dir, idx_test=emotion_to_index_mma, idx_sfew=index_to_emotion_sfew, split="Val", transform=eval_transforms)
loader_sfew = torch.utils.data.DataLoader(dataset_sfew, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

root_dir = "/kaggle/input/mma-facial-expression/MMAFEDB"
dataset_mma = MMADataset(root_dir=root_dir, idx_test=emotion_to_index_mma, idx_mma=index_to_emotion_mma, split="test", transform=eval_transforms)
loader_mma = torch.utils.data.DataLoader(dataset_mma, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

root_dir = "/kaggle/input/affectnetaligned/AffectNetCustom"
dataset_affect = AffectNetDataset(root_dir=root_dir, idx_test=emotion_to_index_mma, idx_aff=index_to_emotion_aff, split="test", transform=eval_transforms)
loader_affect = torch.utils.data.DataLoader(dataset_affect, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

dataset_raf = RafDataSet('/kaggle/input/eacdata/raf-basic', idxs_test=emotion_to_index_mma, idxs_raf=index_to_emotion_raf, train=False, transform=eval_transforms)
loader_raf = torch.utils.data.DataLoader(dataset_raf,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.workers,
                                           pin_memory=True)

root_dir = "/kaggle/input/ferplus/FER2013Plus"
dataset_fer = FERPlusDataset(root_dir=root_dir, idxs_test=emotion_to_index_mma, idxs_fer=index_to_emotion_fer, subset="FER2013Test", transform=eval_transforms)
loader_fer = torch.utils.data.DataLoader(dataset_fer, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

In [None]:
# Load the best MMA checkpoint. map_location=device lets a checkpoint saved on a GPU
# be loaded in a session whose `device` is the CPU (and vice versa).
checkpoint = torch.load('/kaggle/working/ours_best_MMA.pth', map_location=device)
checkpoint = checkpoint["model_state_dict"]
model = load_pretrained_weights(model, checkpoint)
model.to(device)

In [None]:
# Evaluate the current `model` (expected to hold the best MMA weights) on each
# dataset's test loader. `test` returns (accuracy, loss). Each accuracy is kept in its
# own `acc_mma_<dataset>` variable; `test_loss` is overwritten by every call.
###### SFEW
acc_mma_sfew, test_loss = test(model, loader_sfew, device)
print('test acc dbtrain-mma dbtest-sfew: ', acc_mma_sfew)

###### FER+
acc_mma_fer, test_loss = test(model, loader_fer, device)
print('test acc dbtrain-mma dbtest-fer: ', acc_mma_fer)

###### AffectNet
acc_mma_affect, test_loss = test(model, loader_affect, device)
print('test acc dbtrain-mma dbtest-affect: ', acc_mma_affect)

###### MMA (in-domain)
acc_mma_mma, test_loss = test(model, loader_mma, device)
print('test acc dbtrain-mma dbtest-mma: ', acc_mma_mma)

###### RAF-DB
acc_mma_raf, test_loss = test(model, loader_raf, device)
print('test acc dbtrain-mma dbtest-raf: ', acc_mma_raf)

### RAF-DB

In [None]:
# Rebuild every test loader so labels are expressed in the RAF-DB label space
# (idx_test/idxs_test=emotion_to_index_raf). All loaders use eval_transforms;
# train_transforms presumably adds random augmentation, which would make accuracy noisy.
root_dir = "/kaggle/input/datasetsfew"
dataset_sfew = SFEWDataset(root_dir=root_dir, idx_test=emotion_to_index_raf, idx_sfew=index_to_emotion_sfew, split="Val", transform=eval_transforms)
loader_sfew = torch.utils.data.DataLoader(dataset_sfew, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

root_dir = "/kaggle/input/mma-facial-expression/MMAFEDB"
dataset_mma = MMADataset(root_dir=root_dir, idx_test=emotion_to_index_raf, idx_mma=index_to_emotion_mma, split="test", transform=eval_transforms)
loader_mma = torch.utils.data.DataLoader(dataset_mma, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

root_dir = "/kaggle/input/affectnetaligned/AffectNetCustom"
dataset_affect = AffectNetDataset(root_dir=root_dir, idx_test=emotion_to_index_raf, idx_aff=index_to_emotion_aff, split="test", transform=eval_transforms)
loader_affect = torch.utils.data.DataLoader(dataset_affect, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

dataset_raf = RafDataSet('/kaggle/input/eacdata/raf-basic', idxs_test=emotion_to_index_raf, idxs_raf=index_to_emotion_raf, train=False, transform=eval_transforms)
loader_raf = torch.utils.data.DataLoader(dataset_raf,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.workers,
                                           pin_memory=True)

root_dir = "/kaggle/input/ferplus/FER2013Plus"
dataset_fer = FERPlusDataset(root_dir=root_dir, idxs_test=emotion_to_index_raf, idxs_fer=index_to_emotion_fer, subset="FER2013Test", transform=eval_transforms)
loader_fer = torch.utils.data.DataLoader(dataset_fer, batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

In [None]:
# Load the best RAF-DB checkpoint. map_location=device lets a checkpoint saved on a GPU
# be loaded in a session whose `device` is the CPU (and vice versa).
checkpoint = torch.load('/kaggle/working/ours_best_RAFDB.pth', map_location=device)
checkpoint = checkpoint["model_state_dict"]
model = load_pretrained_weights(model, checkpoint)
model.to(device)

In [None]:
# Evaluate the current `model` (expected to hold the best RAF-DB weights) on each
# dataset's test loader. `test` returns (accuracy, loss). Each accuracy is kept in its
# own `acc_raf_<dataset>` variable; `test_loss` is overwritten by every call.
###### SFEW
acc_raf_sfew, test_loss = test(model, loader_sfew, device)
print('test acc dbtrain-raf dbtest-sfew: ', acc_raf_sfew)

###### FER+
acc_raf_fer, test_loss = test(model, loader_fer, device)
print('test acc dbtrain-raf dbtest-fer: ', acc_raf_fer)

###### AffectNet
acc_raf_affect, test_loss = test(model, loader_affect, device)
print('test acc dbtrain-raf dbtest-affect: ', acc_raf_affect)

###### MMA
acc_raf_mma, test_loss = test(model, loader_mma, device)
print('test acc dbtrain-raf dbtest-mma: ', acc_raf_mma)

###### RAF-DB (in-domain)
acc_raf_raf, test_loss = test(model, loader_raf, device)
print('test acc dbtrain-raf dbtest-raf: ', acc_raf_raf)