# BERT Fine-tuning with GRAFT

Setup notebook environment and imports.

In [1]:
# Setup notebook environment
import os
import sys
from pathlib import Path

# Get absolute path to project root
notebook_path = Path('.').resolve()
project_root = str(notebook_path.parent)

# Add project root to Python path
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    print(f"Added {project_root} to Python path")

# Verify imports will work
import decompositions
import grad_dist
print("Project modules found successfully!")

Added /home/ashishjv1/CODE/GRAFT-Main to Python path
Project modules found successfully!
Project modules found successfully!


In [2]:
# Install required dependencies
# !pip install -q sentence-transformers transformers datasets wandb eco2ai

In [3]:
# Core imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.optim.lr_scheduler as lr_scheduler
from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset
from datasets import load_dataset
# Project imports - using absolute imports
from decompositions import index_sel
from grad_dist import calnorm
from utils.model_mapper import ModelMapper
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
import tqdm
# Third party imports - with error handling
try:
    import wandb
    import eco2ai
    EXTERNAL_IMPORTS_OK = True
except ImportError as e:
    print(f"Warning: Some dependencies not found - {e}")
    print("Please run: pip install sentence-transformers transformers wandb eco2ai")
    EXTERNAL_IMPORTS_OK = False

# System utilities
import math
import itertools
import copy
import gc

print(f"Using project root: {project_root}")
print(f"Python path: {sys.path}")

  from .autonotebook import tqdm as notebook_tqdm
2025-05-26 18:55:48.887348: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-26 18:55:48.887393: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-26 18:55:48.888753: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-26 18:55:48.896208: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
202

Using project root: /home/ashishjv1/CODE/GRAFT-Main
Python path: ['/home/ashishjv1/CODE/GRAFT-Main', '/home/ashishjv1/miniconda3/envs/transformers/lib/python39.zip', '/home/ashishjv1/miniconda3/envs/transformers/lib/python3.9', '/home/ashishjv1/miniconda3/envs/transformers/lib/python3.9/lib-dynload', '', '/home/ashishjv1/miniconda3/envs/transformers/lib/python3.9/site-packages', '/home/ashishjv1/OPS_contrib/transformers/src', '/home/ashishjv1/miniconda3/envs/transformers/lib/python3.9/site-packages/setuptools/_vendor', '/tmp/tmpar__xstu']


In [4]:
# !pip install eco2ai

In [5]:
# download_dataset()

# df = load_dataset_into_to_dataframe()
# partition_dataset(df)

In [6]:
# !pip install transformers 

In [5]:
df_train = pd.read_csv("train.csv")
df_val = pd.read_csv("val.csv")
df_test = pd.read_csv("test.csv")

In [6]:
imdb_dataset = load_dataset(
    "csv",
    data_files={
        "train": "train.csv",
        "validation": "val.csv",
        "test": "test.csv",
    },
)

print(imdb_dataset)

DatasetDict({
    train: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 35000
    })
    validation: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 10000
    })
})


In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "distilbert-base-uncased",
)
print("Tokenizer input max length:", tokenizer.model_max_length)
print("Tokenizer vocabulary size:", tokenizer.vocab_size)



Tokenizer input max length: 512
Tokenizer vocabulary size: 30522


In [8]:
def tokenize_text(batch):
    return tokenizer(batch["text"], truncation=True, padding=True)

In [9]:
imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)


Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map: 100%|██████████| 5000/5000 [00:01<00:00, 3276.65 examples/s]
Map: 100%|██████████| 5000/5000 [00:01<00:00, 3276.65 examples/s]


In [10]:
imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])

In [11]:
imdb_dataset

DatasetDict({
    train: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 35000
    })
    validation: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 10000
    })
})

In [12]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [13]:


class IMDBDataset(Dataset):
    def __init__(self, dataset_dict, partition_key="train"):
        self.partition = dataset_dict[partition_key]

    def __getitem__(self, index):
        return self.partition[index]

    def __len__(self):
        return self.partition.num_rows

In [14]:
train_dataset1 = IMDBDataset(imdb_dataset, partition_key="train")
val_dataset1 = IMDBDataset(imdb_dataset, partition_key="validation")
test_dataset1 = IMDBDataset(imdb_dataset, partition_key="test")

In [15]:
train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

In [16]:
train_loader1 = DataLoader(
    dataset=train_dataset1,
    batch_size=100,
    shuffle=False, 
    num_workers=4
)

val_loader1 = DataLoader(
    dataset=val_dataset1,
    batch_size=100,
    num_workers=4
)

test_loader1 = DataLoader(
    dataset=test_dataset1,
    batch_size=100,
    num_workers=4
)

In [17]:


train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=False, 
    num_workers=4
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=100,
    num_workers=4
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=100,
    num_workers=4
)

In [19]:


####Some Changes to rank_selection Method for BERT fine-tuning

def sample_selection(trainloader, data3, model, batch_size, fraction, sel_iter, numEpochs, device):


    indices = []
    l2 = []    
    len_ranks = batch_size * fraction

    ranks = np.arange(int(len_ranks - (len_ranks * fraction)), int((len_ranks + (len_ranks * fraction))), 1, dtype=int)
    num_selections = int(numEpochs / sel_iter)
    candidates = math.ceil(len(ranks) / num_selections)
        
    candidate_ranks = list(np.random.choice(list(ranks), size=candidates, replace=False))
    if len(candidate_ranks) > 3:
        candidate_ranks = list(np.random.choice(list(candidate_ranks), size=3, replace=False))
    print("current selected rank candidates:", candidate_ranks)
    
    
    
    for _, ((trainsamples), V) in enumerate(tqdm.tqdm(zip(trainloader, data3))):

    
        cached_state_dict = copy.deepcopy(model.state_dict())
        clone_dict = copy.deepcopy(model.state_dict())
        
#         net.load_state_dict(cached_state_dict)
        
        A = trainsamples["attention_mask"].T
        
        out = model(trainsamples, indices=None, last=True, freeze=True)
          
        loss = out["loss"]
        loss.backward()
        l_grad = model.model.classifier.weight.grad[0]
        l0_grad = copy.deepcopy(l_grad)
        model.zero_grad()
        

        distance_dict = {}

        for ranks in candidate_ranks:
            model.load_state_dict(cached_state_dict)
            
            idx2 = index_sel(V,  min(ranks, A.shape[1]))
            idx2 = list(set((itertools.chain(*idx2))))
            
            
            out_idx = model(trainsamples, indices=idx2, last=True, freeze=True)
            loss_idx = out_idx["loss"]
            loss_idx.backward()
            l0_idx_grad = model.model.classifier.weight.grad[0]
            distance = calnorm(l0_idx_grad, l0_grad)
            distance_dict[tuple(idx2)] = distance 

        
        indices.append(list(min(distance_dict, key=distance_dict.get)))


    del cached_state_dict
    del clone_dict
    del model
    torch.cuda.empty_cache()    
    gc.collect()
    
    l2 = indices[0]    
    for i in range(len(indices) - 1):
        l2 = l2 + list(np.array(l2[-1]) + np.array(indices[i + 1]))
    
    return l2

In [None]:
# sys.path.append('/beegfs/home/a.jha/DEIM_IS-Tests/')


arguments = type('', (), {'model': 'bert', 'numClasses': 2, 'device': 'cuda'})()
model_mapper = ModelMapper(arguments)
net = model_mapper.get_model()

In [None]:
fraction = 0.35
batch_size = 100
# sel_iter = 10
numEpochs = 30
model_name = "bert"
optimizer_name = "adam"
weight_decay = 0.0001
lr = 5e-5
grad_clip = 0.00
selection_iter = 10
device = "cuda"
sched = "cosine"

selection = 0
dataset_name = "IMDB"


wandb.login()
config = {"lr": lr, "batch_size": batch_size}
config.update({"architecture": f'{net}'})

warm_start = True

if optimizer_name.lower() == "adam":
    optimizer = optim.Adam(net.parameters(), lr=lr)
else:
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay = weight_decay)

scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

if warm_start:
    ttype = "warm"
else:
    ttype = "nowarm"

    
trn_losses = list()
val_losses = list()
trn_acc = list()
val_acc = list()

In [None]:
tracker = eco2ai.Tracker(
project_name=f"BERT_dset-IMDB_bs-{batch_size}", 
experiment_description="FineTune BERT",
file_name=f"emission_-{model_name}_dset-{dataset_name}_bs-{batch_size}_epochs-{numEpochs}_fraction-{fraction}_{optimizer_name}_{ttype}.csv"
)
tracker.start()

If you use a VPN, you may have problems with identifying your country by IP.
It is recommended to disable VPN or
manually install the ISO-Alpha-2 code of your country during initialization of the Tracker() class.
You can find the ISO-Alpha-2 code of your country here: https://www.iban.com/country-codes

  """


In [None]:
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")

device = "cuda"
decomp_type = "torch"

V_list = []
for idx, (batch) in enumerate(tqdm(train_loader1)):
    
    embeddings = sentence_model.encode(batch["text"])
#     print(embeddings.shape)
    outs = np.reshape(embeddings,(-1,embeddings.shape[0]))
    if decomp_type == "torch":
        U, S, Vt = torch.linalg.svd(torch.tensor(outs).to(device),full_matrices=False)
        Vt = Vt.detach().cpu().numpy()
    else:
        U, S, Vt = np.linalg.svd(outs,full_matrices=False)
        
    V_list.append(Vt)

100%|██████████| 350/350 [01:06<00:00,  5.29it/s]


In [None]:
# tracker.start()
for epoch in range(numEpochs):
#     before_lr = optimizer.param_groups[0]["lr"]
    net.train()
    if (epoch) % selection_iter == 0:
        if warm_start and selection == 0:
            trainloader = train_loader
            selection += 1
        else:
            train_model = net
            cached_state_dict = copy.deepcopy(train_model.state_dict())
#             clone_dict = copy.deepcopy(train_model.state_dict())
            indices = sample_selection(train_loader, V_list, train_model, batch_size, fraction, selection_iter, numEpochs, device)
#             indices = rank_selection(train_loader, V_list, net, batch_size, sel_iter, numEpochs, device="cuda")
            indices = [int(i) for i in indices]
            net.load_state_dict(cached_state_dict)

            selection += 1

            datasubset = torch.utils.data.Subset(train_dataset, indices)
            new_trainloader = torch.utils.data.DataLoader(datasubset, batch_size=batch_size,
                                            shuffle=True, pin_memory=False, num_workers=1)

            trainloader = new_trainloader
            
            ## Unfreeze all parameters for learning
            for param in net.parameters():
                param.requires_grad = True
            
            del train_model
    
    
    
    for idx, (batch) in enumerate(tqdm.tqdm(trainloader)):
        if wandb.run is None:
            name = f"wd{weight_decay}_opt{optimizer_name}_bs{batch_size}_gclip{grad_clip}_lr{lr}_f{fraction}_siter{selection_iter}"
            wandb.init(project=f"Smart_Sampling_{model_name}_{dataset_name}", config=config, name=name)
        outputs = net(batch, indices=None, last=False, freeze=False)
        
        logits = outputs["logits"]
        
        loss = outputs["loss"]

        
        optimizer.zero_grad()
        loss.backward()
        if grad_clip:
            nn.utils.clip_grad_value_(net.parameters(), grad_clip)
        optimizer.step()    

    
    if (epoch+1) % 1 == 0:
                trn_loss = 0
                trn_correct = 0
                trn_total = 0
                val_loss = 0
                val_correct = 0
                val_total = 0
                tst_correct = 0
                tst_total = 0
                tst_loss = 0
                curr_high = 0
                net.eval()
                with torch.no_grad():
                    for _, (batch) in enumerate(trainloader):
#                             inputs, targets = inputs.to(device), \
#                                               targets.to(device, non_blocking=True)
                            output = net(batch, indices=None, last=False, freeze=False)
#                             loss = torch.nn.functional.cross_entropy(outputs, targets)
                            trn_loss += output["loss"]
                            logits = output["logits"]
                            predicted = torch.argmax(logits, 1)
            
                            targets = batch["label"].to(device)
                
                            trn_total += targets.size(0)
                            trn_correct += predicted.eq(targets).sum().item()
                    trn_losses.append(trn_loss)
                    trn_acc.append(trn_correct / trn_total)
                    
                with torch.no_grad():        
                    for _, (batch) in enumerate(val_loader):
                            output = net(batch, indices=None, last=False, freeze=False)
                            val_loss += output["loss"]
                            logits = output["logits"]
                            predicted = torch.argmax(logits, 1)
                            targets = batch["label"].to(device)
                            val_total += targets.size(0)
                            val_correct += predicted.eq(targets).sum().item()
                    val_losses.append(val_loss)
                    val_acc.append(val_correct / val_total)

                if val_acc[-1] > curr_high:
                    curr_high = val_acc[-1]


                wandb.log({"Validation accuracy": curr_high, "Val Loss":val_losses[-1]/100,
                        "loss": trn_losses[-1]/100, "Train Accuracy": trn_acc[-1]*100, "Epoch": epoch})
                
                
                wandb.log({"loss": trn_losses[-1]/100, "Train Accuracy": trn_acc[-1]*100, "Epoch": epoch})      

                print("Epoch [{}/{}], Loss: {:.4f}, Train Accuracy: {:.2f}%".format(epoch+1, numEpochs,
                                                                                    trn_losses[-1],
                                                                                    trn_acc[-1]*100))

                print("Highest Accuracy:", curr_high)
                print("Validation Accuracy:", val_acc[-1])
                print("Validation Loss", val_losses[-1])
tracker.stop()  


  0%|          | 0/350 [00:00<?, ?it/s]

100%|██████████| 350/350 [06:11<00:00,  1.06s/it]


Epoch [1/30], Loss: 37.0972, Train Accuracy: 96.46%
Highest Accuracy: 0.9344
Validation Accuracy: 0.9344
Validation Loss tensor(8.9431, device='cuda:0')


100%|██████████| 350/350 [06:06<00:00,  1.05s/it]


Epoch [2/30], Loss: 20.6142, Train Accuracy: 98.08%
Highest Accuracy: 0.9308
Validation Accuracy: 0.9308
Validation Loss tensor(10.1698, device='cuda:0')


100%|██████████| 350/350 [06:06<00:00,  1.05s/it]


Epoch [3/30], Loss: 9.9796, Train Accuracy: 99.14%
Highest Accuracy: 0.9318
Validation Accuracy: 0.9318
Validation Loss tensor(11.2833, device='cuda:0')


100%|██████████| 350/350 [06:06<00:00,  1.05s/it]


Epoch [4/30], Loss: 9.0185, Train Accuracy: 99.20%
Highest Accuracy: 0.9284
Validation Accuracy: 0.9284
Validation Loss tensor(11.6260, device='cuda:0')


100%|██████████| 350/350 [06:06<00:00,  1.05s/it]


Epoch [5/30], Loss: 6.3079, Train Accuracy: 99.39%
Highest Accuracy: 0.927
Validation Accuracy: 0.927
Validation Loss tensor(14.1805, device='cuda:0')


100%|██████████| 350/350 [06:06<00:00,  1.05s/it]


Epoch [6/30], Loss: 4.5874, Train Accuracy: 99.58%
Highest Accuracy: 0.931
Validation Accuracy: 0.931
Validation Loss tensor(13.8740, device='cuda:0')


100%|██████████| 350/350 [06:06<00:00,  1.05s/it]


Epoch [7/30], Loss: 2.4680, Train Accuracy: 99.79%
Highest Accuracy: 0.931
Validation Accuracy: 0.931
Validation Loss tensor(14.1873, device='cuda:0')


100%|██████████| 350/350 [06:06<00:00,  1.05s/it]


Epoch [8/30], Loss: 2.6307, Train Accuracy: 99.83%
Highest Accuracy: 0.9296
Validation Accuracy: 0.9296
Validation Loss tensor(14.7058, device='cuda:0')


100%|██████████| 350/350 [06:06<00:00,  1.05s/it]


Epoch [9/30], Loss: 4.7977, Train Accuracy: 99.60%
Highest Accuracy: 0.9246
Validation Accuracy: 0.9246
Validation Loss tensor(19.2005, device='cuda:0')


100%|██████████| 350/350 [06:06<00:00,  1.05s/it]


Epoch [10/30], Loss: 1.9311, Train Accuracy: 99.82%
Highest Accuracy: 0.9284
Validation Accuracy: 0.9284
Validation Loss tensor(16.6825, device='cuda:0')
current selected rank candidates: [44, 25, 23]


350it [04:00,  1.45it/s]
100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [11/30], Loss: 0.1621, Train Accuracy: 99.95%
Highest Accuracy: 0.9322
Validation Accuracy: 0.9322
Validation Loss tensor(19.2316, device='cuda:0')


100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [12/30], Loss: 0.2363, Train Accuracy: 99.93%
Highest Accuracy: 0.93
Validation Accuracy: 0.93
Validation Loss tensor(19.1298, device='cuda:0')


100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [13/30], Loss: 0.0528, Train Accuracy: 99.99%
Highest Accuracy: 0.9322
Validation Accuracy: 0.9322
Validation Loss tensor(18.3071, device='cuda:0')


100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [14/30], Loss: 0.0442, Train Accuracy: 99.99%
Highest Accuracy: 0.9364
Validation Accuracy: 0.9364
Validation Loss tensor(20.3868, device='cuda:0')


100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [15/30], Loss: 0.0477, Train Accuracy: 100.00%
Highest Accuracy: 0.9306
Validation Accuracy: 0.9306
Validation Loss tensor(19.6749, device='cuda:0')


100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [16/30], Loss: 0.0087, Train Accuracy: 100.00%
Highest Accuracy: 0.9342
Validation Accuracy: 0.9342
Validation Loss tensor(19.2319, device='cuda:0')


100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [17/30], Loss: 0.0749, Train Accuracy: 99.99%
Highest Accuracy: 0.9298
Validation Accuracy: 0.9298
Validation Loss tensor(17.7693, device='cuda:0')


100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [18/30], Loss: 0.1542, Train Accuracy: 99.90%
Highest Accuracy: 0.9314
Validation Accuracy: 0.9314
Validation Loss tensor(20.5747, device='cuda:0')


100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [19/30], Loss: 0.0061, Train Accuracy: 100.00%
Highest Accuracy: 0.9328
Validation Accuracy: 0.9328
Validation Loss tensor(19.0781, device='cuda:0')


100%|██████████| 95/95 [01:38<00:00,  1.04s/it]


Epoch [20/30], Loss: 0.0080, Train Accuracy: 100.00%
Highest Accuracy: 0.9304
Validation Accuracy: 0.9304
Validation Loss tensor(22.0908, device='cuda:0')
current selected rank candidates: [43, 37, 29]


350it [04:21,  1.34it/s]
100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [21/30], Loss: 0.8001, Train Accuracy: 99.73%
Highest Accuracy: 0.925
Validation Accuracy: 0.925
Validation Loss tensor(19.0556, device='cuda:0')


100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [22/30], Loss: 0.0558, Train Accuracy: 99.98%
Highest Accuracy: 0.9324
Validation Accuracy: 0.9324
Validation Loss tensor(20.1032, device='cuda:0')


100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [23/30], Loss: 0.0476, Train Accuracy: 99.98%
Highest Accuracy: 0.9344
Validation Accuracy: 0.9344
Validation Loss tensor(18.3998, device='cuda:0')


100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [24/30], Loss: 0.0178, Train Accuracy: 100.00%
Highest Accuracy: 0.937
Validation Accuracy: 0.937
Validation Loss tensor(17.2147, device='cuda:0')


100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [25/30], Loss: 0.0600, Train Accuracy: 99.97%
Highest Accuracy: 0.9298
Validation Accuracy: 0.9298
Validation Loss tensor(21.8764, device='cuda:0')


100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [26/30], Loss: 0.0218, Train Accuracy: 100.00%
Highest Accuracy: 0.9366
Validation Accuracy: 0.9366
Validation Loss tensor(16.4206, device='cuda:0')


100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [27/30], Loss: 0.0450, Train Accuracy: 99.99%
Highest Accuracy: 0.9284
Validation Accuracy: 0.9284
Validation Loss tensor(19.5699, device='cuda:0')


100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [28/30], Loss: 0.3873, Train Accuracy: 99.96%
Highest Accuracy: 0.9322
Validation Accuracy: 0.9322
Validation Loss tensor(19.1495, device='cuda:0')


100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [29/30], Loss: 0.2201, Train Accuracy: 99.96%
Highest Accuracy: 0.9308
Validation Accuracy: 0.9308
Validation Loss tensor(19.0834, device='cuda:0')


100%|██████████| 117/117 [02:01<00:00,  1.04s/it]


Epoch [30/30], Loss: 0.0734, Train Accuracy: 99.98%
Highest Accuracy: 0.9336
Validation Accuracy: 0.9336
Validation Loss tensor(18.1433, device='cuda:0')


In [None]:
!nvidia-smi

Sun May 12 00:06:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:C1:00.0 Off |                    0 |
| N/A   45C    P0              89W / 400W |  35226MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                         