In [7]:
import torch
import pretrain.models_vit as models_vit

from thop import profile

# Build the `sit_base` backbone without a classification head (num_classes=0) and
# report its parameter count and compute cost for a single input.
PROFILE_INPUT_LEN = 3 * 224 * 224  # flattened size of a 3x224x224 image

model = models_vit.__dict__["sit_base"](
    num_classes=0, drop_path_rate=0.1
)
model.eval()  # profile the inference configuration (stochastic paths disabled)

n_params = sum(p.numel() for p in model.parameters())
print(f"parameters (all tensors): {n_params:,}")

# NOTE: thop's `profile` counts multiply-accumulates (MACs), not FLOPs; FLOPs ~= 2 * MACs.
# Its parameter total only sums the submodules it hooks, so free-standing nn.Parameters
# (presumably position embedding / class token here) are likely missing -- hence the gap
# to the count printed above.
# The dummy input is shaped (1, 1, PROFILE_INPUT_LEN), presumably (batch, channels, length).
# NOTE(review): MyFolder yields 2-row samples; confirm the channel count sit_base expects.
with torch.no_grad():
    macs, thop_params = profile(model, inputs=(torch.randn(1, 1, PROFILE_INPUT_LEN),))
print(f"MACs: {macs:,.0f}  (~{2 * macs / 1e9:.1f} GFLOPs)")
print(f"parameters (thop): {thop_params:,.0f}")

85798656
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
16862862336.0
85646592.0


In [6]:
import os

import numpy as np
import torch
from torch.utils.data import ConcatDataset
from torchvision.datasets.folder import DatasetFolder


def random_crop_resize(sample, crop_minlen, input_size):
    """Randomly crop a signal along its last axis, then truncate/zero-pad it to a fixed length.

    Args:
        sample: 1-D array of shape (L,) or 2-D array of shape (C, L). A 1-D array or a
            single-row 2-D array is duplicated into two identical rows. Arrays with more
            than two rows are not supported (the copy into the 2-row output fails).
        crop_minlen: minimum crop length. It is clamped into [0, L], so a signal shorter
            than ``crop_minlen`` is kept whole instead of raising. Passing L yields a
            deterministic crop of the full signal.
        input_size: length of the last axis of the returned array.

    Returns:
        float32 array of shape (2, input_size) holding the crop, truncated to
        ``input_size`` if longer and right-padded with zeros if shorter.
    """
    if sample.ndim == 1 or sample.shape[0] == 1:
        # Promote a mono signal to two identical rows.
        sample = np.vstack([sample, sample])
    sample_len = sample.shape[1]

    # np.random.randint(low, high) requires low < high; without clamping, a signal
    # shorter than crop_minlen raised ValueError.
    crop_minlen = min(max(crop_minlen, 0), sample_len)
    crop_size = np.random.randint(crop_minlen, sample_len + 1)
    start_idx = np.random.randint(0, sample_len - crop_size + 1)

    # Slicing past the end is a no-op, so this only truncates crops longer than input_size.
    cropped = sample[:, start_idx : start_idx + crop_size][:, :input_size]

    sample_padded = np.zeros((2, input_size), dtype=np.float32)
    sample_padded[:, : cropped.shape[1]] = cropped
    return sample_padded


# Extension filter passed to DatasetFolder (see MyFolder): only ".npy" files become samples.
NPY_EXTENSIONS = ".npy"


class MyFolder(DatasetFolder):
    """DatasetFolder over ``.npy`` signal files that yields fixed-length 2-row tensors.

    Files are discovered by torchvision's ``DatasetFolder`` (filtered by
    ``NPY_EXTENSIONS``) and loaded with ``np.load``. Each sample goes through
    ``random_crop_resize`` and comes out as a float32 tensor of shape (2, input_size).

    Args:
        root: dataset root directory.
        mode: ``"train"`` takes random crops of at least ``crop_minlen`` points. Any
            other value keeps the whole signal (deterministic), truncated or zero-padded
            to ``input_size``.
        input_size: length of the last axis of every returned sample.
        dataset_idx: index of this folder among several datasets; scales the label offset.
        domain_classnum: label offset per dataset index. Presumably the class count per
            dataset, so folders combined via ``ConcatDataset`` get disjoint label
            ranges -- TODO confirm with callers.
        crop_minlen: minimum random-crop length in train mode (default 512, the
            previously hard-coded value).
    """

    def __init__(self, root, mode, input_size, dataset_idx=0, domain_classnum=0, crop_minlen=512):
        super().__init__(
            root=root,
            loader=np.load,
            extensions=NPY_EXTENSIONS,
        )
        self.mode = mode
        self.input_size = input_size
        self.dataset_idx = dataset_idx
        self.domain_classnum = domain_classnum
        self.crop_minlen = crop_minlen

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(sample, target)``.

        ``sample`` is a float32 tensor of shape (2, input_size). ``target`` is the folder
        class index shifted by ``dataset_idx * domain_classnum``.
        """
        path, target = self.samples[index]

        sample = self.loader(path)
        if self.mode == "train":
            crop_minlen = self.crop_minlen
        else:
            # Use the signal length (last axis), not len(sample): for a 2-D (C, L) array
            # len() is the channel count, which made evaluation crops random.
            crop_minlen = sample.shape[-1]
        sample = random_crop_resize(sample, crop_minlen, self.input_size)
        sample = torch.from_numpy(sample.reshape((2, self.input_size)))
        target = self.dataset_idx * self.domain_classnum + target

        return sample, target
    
# Smoke test: a random 1-D signal of 1024 points should come out of random_crop_resize
# as a zero-padded float32 array of shape (2, 3*224*224).
DEMO_INPUT_SIZE = 224 * 224 * 3
np.random.seed(0)  # make the random crop reproducible on re-run
sample = np.random.rand(1024)
sample = random_crop_resize(sample, 512, DEMO_INPUT_SIZE)
assert sample.shape == (2, DEMO_INPUT_SIZE) and sample.dtype == np.float32, sample.shape
print(sample.shape)

(2, 150528)
