In [6]:
import warnings
warnings.filterwarnings("ignore")

import torch
import numpy as np
import torchvision
import matplotlib.pyplot as plt
import multiprocessing
from transformers import *
from datasets import load_dataset
from PIL import Image
from torchinfo import summary
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
import time
import timm

torch.cuda.empty_cache()
device="cuda" if torch.cuda.is_available() else "cpu"

## Importing Data

In [30]:
def transform(examples):
  inputs = image_processor([img.convert("RGB") for img in examples["image"]], return_tensors="pt")
  inputs["labels"] = examples["label"]
  return inputs

def collate_fn(batch):
  return {
      "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
      "labels": torch.tensor([x["labels"] for x in batch]),
  }

model_name = "google/vit-base-patch16-224"
batch_size = 16
cpu_count=multiprocessing.cpu_count()

image_processor = ViTImageProcessor.from_pretrained(model_name)

train_ds= load_dataset('../../../Desktop_SIH/SIH/Main_Dataset/')
train_ds = train_ds["train"].train_test_split(test_size=0.25) 


#train_ds= load_dataset('chest_xray/train/',num_proc=cpu_count)
#test_ds= load_dataset('chest_xray/test/',num_proc=cpu_count)

labels = train_ds["train"].features["label"].names
#dataset_train = train_ds.with_transform(transform)
#dataset_test=test_ds.with_transform(transform)
dataset = train_ds.with_transform(transform)


train_dataset_loader = torch.utils.data.DataLoader(dataset["train"], collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
valid_dataset_loader = torch.utils.data.DataLoader(dataset["test"], collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

#train_dataset_loader = torch.utils.data.DataLoader(dataset_train["train"], collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
#valid_dataset_loader = torch.utils.data.DataLoader(dataset_test["train"], collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

loading configuration file preprocessor_config.json from cache at /home/moose/.cache/huggingface/hub/models--google--vit-base-patch16-224/snapshots/3f49326eb077187dfe1c2a2bb15fbd74e6ab91e3/preprocessor_config.json
size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'shortest_edge', 'longest_edge'}, {'longest_edge'}), got 224. Converted to {'height': 224, 'width': 224}.
Image processor ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}



Resolving data files:   0%|          | 0/14462 [00:00<?, ?it/s]

## Model Cr
eation

In [33]:
model_name="vit_base_patch16_384.orig_in21k_ft_in1k"

model = timm.create_model(model_name, pretrained=True)

data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

In [54]:
model.blocks[len(model.blocks)-1].norm1

LayerNorm((768,), eps=1e-06, elementwise_affine=True)

In [55]:
model.blocks[len(model.blocks)-1]=torch.nn.Linear(in_features=768,out_features=768)

In [57]:
model.blocks[0]=torch.nn.Linear(in_features=768,out_features=768)

In [70]:
model.blocks[1].attn.qkv=torch.nn.Linear(in_features=768,out_features=1)

In [71]:
model.blocks

Sequential(
  (0): Linear(in_features=768, out_features=768, bias=True)
  (1): Block(
    (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (attn): Attention(
      (qkv): Linear(in_features=768, out_features=1, bias=True)
      (q_norm): Identity()
      (k_norm): Identity()
      (attn_drop): Dropout(p=0.0, inplace=False)
      (proj): Linear(in_features=768, out_features=768, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (ls1): Identity()
    (drop_path1): Identity()
    (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (act): GELU(approximate='none')
      (drop1): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (drop2): Dropout(p=0.0, inplace=False)
    )
    (ls2): Identity()
    (drop_path2): Identity()
  )
  (2): Block(
    (norm1): LayerNorm((768,), eps=1e-0

In [72]:
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Linear(in_features=768, out_features=768, bias=True)
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=1, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop

In [99]:
dict(model.state_dict())["cls_token"]

tensor([[[ 5.5241e-03,  1.3605e-02, -2.5943e-01, -1.5176e-02,  4.0256e-01,
           4.6956e-02, -7.6992e-03,  1.1583e-02,  4.1794e-02, -2.2618e-01,
          -8.6113e-03, -1.3967e-02, -1.5123e-02, -1.3010e-02, -2.0492e-02,
          -8.5839e-03,  3.3414e-03,  5.7678e-02,  3.6365e-02, -5.6758e-03,
          -4.7974e-02,  5.9637e-03, -5.0221e-03, -1.3371e-02, -6.9871e-03,
           5.0253e-02,  7.3932e-03, -5.7437e-03,  2.1240e-02, -1.8942e-02,
           1.5181e-03,  2.5193e-02, -2.0468e-02,  1.6701e-02,  1.9216e-02,
          -9.3600e-04,  5.2248e-02,  9.3180e-03, -4.8390e-03,  1.8174e-03,
           1.4953e-02,  1.0364e-02,  1.7181e-02, -2.5523e-03,  6.5110e-02,
           4.3907e-01,  1.6787e-02,  4.5343e-03,  1.9626e-02, -1.3791e-03,
          -5.4258e-03, -4.4043e-02,  3.1062e-03, -1.2638e-02,  8.5818e-03,
          -7.1245e-03,  9.0131e-04, -2.6523e-03,  5.1280e-03, -7.6498e-02,
           6.5924e-04,  2.2964e-03,  7.7230e-03,  1.2007e-02,  4.1369e-02,
           1.1783e-02,  3