# Fine-tuning for Video Classification with 🤗 Transformers using ViViT
### Abstract for ViViT
We present pure-transformer based models for video classification, drawing upon the recent success of such models in image classification. Our model extracts spatio-temporal tokens from the input video, which are then encoded by a series of transformer layers. In order to handle the long sequences of tokens encountered in video, we propose several, efficient variants of our model which factorise the spatial- and temporal-dimensions of the input. Although transformer-based models are known to only be effective when large training datasets are available, we show how we can effectively regularise the model during training and leverage pretrained image models to be able to train on comparatively small datasets. We conduct thorough ablation studies, and achieve state-of-the-art results on multiple video classification benchmarks including Kinetics 400 and 600, Epic Kitchens, Something-Something v2 and Moments in Time, outperforming prior methods based on deep 3D convolutional networks. To facilitate further research, we release code at https://github.com/google-research/scenic/tree/main/scenic/projects/vivit

https://arxiv.org/pdf/2103.15691

## Embeddings
### Uniform frame sampling 
A straightforward method of tokenising the input video is to uniformly sample $n_t$ frames from the input video clip, embed each 2D frame independently using the same method as ViT, and concatenate all these tokens together. Concretely, if $n_h \cdot n_w$ non-overlapping image patches are extracted from each frame, then a total of $n_t \cdot n_h \cdot n_w$ tokens will be forwarded through the transformer encoder. Intuitively, this process may be seen as simply constructing a large 2D image to be tokenised following ViT.
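
As a quick worked example of the token count, the sketch below evaluates $n_t \cdot n_h \cdot n_w$ for the clip shape used later in this notebook (10 frames of 224×224 pixels). The 16×16 patch size is an assumption taken from the "16x2" in the checkpoint name, and `num_tokens_uniform` is a hypothetical helper, not part of any library.

```python
def num_tokens_uniform(n_t: int, height: int, width: int, patch: int) -> int:
    """Tokens produced by uniform frame sampling: n_t * n_h * n_w."""
    n_h = height // patch  # patches along the height of one frame
    n_w = width // patch   # patches along the width of one frame
    return n_t * n_h * n_w


# 10 frames of 224x224 pixels, assumed 16x16 patches.
tokens = num_tokens_uniform(n_t=10, height=224, width=224, patch=16)
print(tokens)  # 10 * 14 * 14 = 1960
```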

### Tubelet embedding
An alternative method is to extract non-overlapping, spatio-temporal "tubes" from the input volume and to linearly project each of them to $\mathbb{R}^d$. This method is an extension of ViT's embedding to 3D, and corresponds to a 3D convolution.
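
To make the tubelet idea concrete, here is a minimal NumPy sketch, not the ViViT implementation itself. It cuts a `(T, H, W, C)` video into non-overlapping `t × h × w` tubes, flattens each tube, and applies one shared linear projection, which is what a 3D convolution with stride equal to its kernel size computes (bias omitted). The tube size 2×16×16, the small embedding size `d = 8`, and the random weights are illustrative assumptions.

```python
import numpy as np


def tubelet_embed(video, t, h, w, weight):
    """Project non-overlapping t x h x w tubes of a (T, H, W, C) video to d dims."""
    T, H, W, C = video.shape
    n_t, n_h, n_w = T // t, H // h, W // w
    # Crop any remainder so the volume divides evenly into tubes.
    v = video[: n_t * t, : n_h * h, : n_w * w]
    # (n_t, t, n_h, h, n_w, w, C) -> (n_t, n_h, n_w, t, h, w, C)
    v = v.reshape(n_t, t, n_h, h, n_w, w, C).transpose(0, 2, 4, 1, 3, 5, 6)
    tubes = v.reshape(n_t * n_h * n_w, t * h * w * C)  # one row per tube
    return tubes @ weight  # (num_tokens, d)


rng = np.random.default_rng(0)
video = rng.random((10, 224, 224, 3), dtype=np.float32)  # stand-in clip
d = 8  # small embedding size, for illustration only
weight = rng.standard_normal((2 * 16 * 16 * 3, d)).astype(np.float32)

tokens = tubelet_embed(video, t=2, h=16, w=16, weight=weight)
print(tokens.shape)  # (980, 8): 5 * 14 * 14 tubes, each projected to d = 8
```

Compared with uniform frame sampling, each token here already mixes information from `t` consecutive frames, so the sequence is `t` times shorter for the same clip.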

### Hugging Face ViViT
https://huggingface.co/docs/transformers/main/model_doc/vivit

# Pretraining Dataset
https://paperswithcode.com/dataset/kinetics-400-1

# Base Model

https://github.com/google-research/scenic/tree/main/scenic/projects/vivit

### google/vivit-b-16x2-kinetics400

https://huggingface.co/docs/transformers/main/model_doc/vivit

## Fine-tune the model using the Duke dataset


## Load the dataset artifacts from wandb

Make sure you have created a [wandb](https://wandb.ai/site/) account before running the cells below.

In [59]:
import wandb

In [60]:
# Initialize wandb
# import os
# wandb_key =  os.getenv("WANDB_API_KEY")
# wandb.login(key=wandb_key)
MODEL_NAME = "google/vivit-b-16x2-kinetics400"
DATASET = "duke iqm yolov5 filtered dataset"
PROJECT = "laryngeal_cancer_video_classification"

In [61]:
run = wandb.init(
    project=PROJECT,
    name="vivit-b-16x2-training-iqm-filtered", # runs process name
    tags=[MODEL_NAME, DATASET],
    entity="shaunliewsmu-singapore-management-university"
)


In [62]:
# Access the dataset through registry name
artifact = run.use_artifact('laryngeal_cancer_video_classification/laryngeal_dataset_iqm_filtered:v0')
dataset_dir = artifact.download()
print(f"\nSuccessfully downloaded dataset to {dataset_dir}")

# Print artifact metadata if available
if hasattr(artifact, 'metadata'):
    print("\nArtifact metadata:")
    for key, value in artifact.metadata.items():
        print(f"{key}: {value}")

[34m[1mwandb[0m: Downloading large artifact laryngeal_dataset_iqm_filtered:v0, 648.15MB. 133 files... 
[34m[1mwandb[0m:   133 of 133 files downloaded.  
Done. 0:0:0.7



Successfully downloaded dataset to /home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_iqm_filtered:v0

Artifact metadata:
total_videos: 132
referral_ratio: 66.67%
total_referral: 88
dataset_structure: {'val': {'total': 20, 'referral': 14, 'non_referral': 6, 'referral_ratio': '70.00%', 'non_referral_ratio': '30.00%'}, 'test': {'total': 20, 'referral': 13, 'non_referral': 7, 'referral_ratio': '65.00%', 'non_referral_ratio': '35.00%'}, 'train': {'total': 92, 'referral': 61, 'non_referral': 31, 'referral_ratio': '66.30%', 'non_referral_ratio': '33.70%'}}
total_non_referral: 44


In [63]:
from transformers import Trainer, TrainingArguments
from data_preprocessing import create_dataset
from data_handling import frames_convert_and_create_dataset_dictionary

In [64]:
from dotenv import load_dotenv
env_path =  ".env"
load_dotenv(env_path)

True

# Data Preprocessing

In [65]:
import model_configuration
from model_configuration import compute_metrics,collate_fn
import av
import numpy as np
from data_handling import sample_frame_indices, read_video_pyav

In [66]:
video_path = "/home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/test/referral/0057_processed.mp4"
container = av.open(video_path)

In [67]:
container.streams.video[0].frames

59

In [68]:
from moviepy import *

In [70]:
container = av.open(video_path)
indices = sample_frame_indices(clip_len=10, frame_sample_rate=2,seg_len=container.streams.video[0].frames)
video = read_video_pyav(container=container, indices=indices)

Reading frames from 26 to 45
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)


In [71]:
indices

array([26, 28, 30, 32, 34, 37, 39, 41, 43, 45])

In [72]:
video.shape

(10, 224, 224, 3)

In [73]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [74]:
path_files = "artifacts/laryngeal_dataset_iqm_filtered:v0"
video_dict, class_labels = frames_convert_and_create_dataset_dictionary(path_files)


Processing file artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/train/referral/0165_processed.mp4 number of Frames: 509
Reading frames from 136 to 145
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
Processing file artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/train/referral/0030_processed.mp4 number of Frames: 40
Reading frames from 29 to 38
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
Processing file artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/train/referral/0197_processed.mp4 number of Frames: 32
Reading frames from 4 to 13
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
Processing file artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/train/referral/0090_processed.mp4 number of Frames: 64
Reading frames from 32 to 41
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
Processing file artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/train/referral/0157_proces

In [75]:
len(video_dict)

95

In [76]:
video_dict[0].keys()

dict_keys(['video', 'labels', 'split', 'path'])

In [77]:
video_dict[0]['video'].shape

(10, 224, 224, 3)

In [78]:
video_dict[0]['labels']

'referral'

In [79]:
num_frames, height, width, channels =  video_dict[0]['video'].shape
num_frames, height, width, channels 

(10, 224, 224, 3)

# Display Video sample

In [80]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

def display_video_sample(video_sample):
    """
    Display a video sample using matplotlib animation
    """
    # Get frames from the sample
    frames = video_sample['video']
    
    # Create figure and axes
    fig, ax = plt.subplots(figsize=(8, 8))
    plt.close() # Prevents displaying the empty figure
    
    # Remove axes
    ax.set_xticks([])
    ax.set_yticks([])
    
    # Display first frame
    img = ax.imshow(frames[0])
    
    def animate(i):
        img.set_array(frames[i])
        return (img,)
    
    # Create animation
    anim = animation.FuncAnimation(
        fig, 
        animate, 
        frames=len(frames),
        interval=100,  # Time between frames in milliseconds
        blit=True
    )
    
    # Display the animation
    return HTML(anim.to_jshtml())

In [81]:
# Display a sample video
sample_video = video_dict[0]
print(f"Displaying video with label: {sample_video['labels']}")
display_video_sample(sample_video)

Displaying video with label: referral


In [82]:
class_labels = sorted(class_labels)
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}

print(f"Unique classes: {list(label2id.keys())}.")

Unique classes: ['non_referral', 'referral'].


In [83]:
print("\nChecking input video format:")
first_video = video_dict[0]
print("Input video shape:", np.array(first_video['video']).shape)
print("Input video type:", type(first_video['video']))


Checking input video format:
Input video shape: (10, 224, 224, 3)
Input video type: <class 'numpy.ndarray'>


In [84]:
shuffled_dataset = create_dataset(video_dict)

Casting to class labels: 100%|██████████| 95/95 [00:00<00:00, 326.63 examples/s]
Processing videos: 100%|██████████| 95/95 [02:50<00:00,  1.80s/ examples]


First sample structure:
Keys: dict_keys(['labels', 'split', 'path', 'pixel_values'])
Pixel values type: <class 'list'>


In [85]:
shuffled_dataset['train'].features

{'labels': ClassLabel(names=['non_referral', 'referral'], id=None),
 'split': Value(dtype='string', id=None),
 'path': Value(dtype='string', id=None),
 'pixel_values': Sequence(feature=Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None)}

In [86]:
# Debug prints
print("\nChecking processed dataset:")
print("Dataset size:", len(shuffled_dataset['train']))

sample = shuffled_dataset['train'][0]
print("\nSample inspection:")
print("Sample keys:", sample.keys())
print("Sample pixel values type:", type(sample['pixel_values']))
if isinstance(sample['pixel_values'], torch.Tensor):
    print("Sample pixel values shape:", sample['pixel_values'].shape)
print("Sample label:", sample['labels'])

# Test batch creation
print("\nTesting batch creation:")
batch = collate_fn([shuffled_dataset['train'][i] for i in range(2)])
print("Batch pixel values shape:", batch['pixel_values'].shape)
print("Batch labels shape:", batch['labels'].shape)


Checking processed dataset:
Dataset size: 85

Sample inspection:
Sample keys: dict_keys(['labels', 'split', 'path', 'pixel_values'])
Sample pixel values type: <class 'list'>
Sample label: 1

Testing batch creation:
Batch pixel values shape: torch.Size([2, 10, 3, 224, 224])
Batch labels shape: torch.Size([2])


In [87]:
model = model_configuration.initialise_model(shuffled_dataset, device)

Some weights of VivitForVideoClassification were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized because the shapes did not match:
- vivit.embeddings.position_embeddings: found shape torch.Size([1, 3137, 768]) in the checkpoint and torch.Size([1, 981, 768]) in the model instantiated
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
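
The two position-embedding shapes in the warning above can be reproduced with simple arithmetic. The sketch below assumes the Transformers defaults for this checkpoint (2×16×16 tubelets, 224×224 frames, 32 frames at pretraining time) and one extra position for the classification token. `vivit_num_positions` is a hypothetical helper.

```python
def vivit_num_positions(num_frames, image_size, tubelet=(2, 16, 16)):
    """Position embeddings needed: one per tubelet token plus one [CLS] token."""
    t, h, w = tubelet
    return (num_frames // t) * (image_size // h) * (image_size // w) + 1


print(vivit_num_positions(32, 224))  # 3137, the shape stored in the checkpoint
print(vivit_num_positions(10, 224))  # 981, the shape for 10-frame clips
```

Because the clip length changes from 32 to 10 frames, the pretrained position embeddings cannot be reused as they are. They are newly initialized alongside the 2-class classifier head.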


In [89]:
from datetime import datetime

training_output_dir = "./training_results"
run_name = f"vivit-duke-iqm-filtered-{datetime.now().strftime('%Y%m%d-%H%M')}"

training_args = TrainingArguments(
    output_dir=training_output_dir,         
    num_train_epochs=40,             
    per_device_train_batch_size=8,   
    per_device_eval_batch_size=8,    
    learning_rate=5e-05,            
    weight_decay=0.01,              
    logging_dir="./logs",           
    logging_steps=10,                
    seed=42,                       
    eval_strategy="steps", 
    eval_steps=10,                   
    warmup_steps=int(0.1 * 20),      
    optim="adamw_torch",          
    lr_scheduler_type="linear",      
    fp16=True,  
    report_to="wandb",
    run_name=run_name
)
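
As a sanity check on these settings, the sketch below works out how many optimizer steps the run should take. The 85 training clips, batch size of 8 and 40 epochs come from this notebook. The device count is an assumption, since the notebook does not state it, and `steps_per_epoch` is a hypothetical helper.

```python
import math


def steps_per_epoch(num_samples, per_device_batch, num_devices=1, grad_accum=1):
    """Optimizer steps in one epoch for a given effective batch size."""
    effective_batch = per_device_batch * num_devices * grad_accum
    return math.ceil(num_samples / effective_batch)


print(steps_per_epoch(85, 8) * 40)                 # 440 with a single device
print(steps_per_epoch(85, 8, num_devices=2) * 40)  # 240 with two devices
print(int(0.1 * 20))                               # 2, the warmup_steps value above
```

The run summary further down reports `train/global_step` of 240, which would be consistent with an effective batch size of 16, for example two devices with 8 clips each.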

In [90]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-05, betas=(0.9, 0.999), eps=1e-08)
# Define the trainer
trainer = Trainer(
    model=model,                      
    args=training_args,              
    train_dataset=shuffled_dataset["train"],      
    eval_dataset=shuffled_dataset["test"],       
    optimizers=(optimizer, None),  
    compute_metrics = compute_metrics,
    data_collator=collate_fn
)
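
`compute_metrics` is imported from the local `model_configuration` module, whose source is not shown in this notebook. As a rough idea of what such a function typically does, the sketch below computes accuracy from a `(logits, labels)` pair. It is a hypothetical stand-in, not the project's actual implementation, and the example logits are made up.

```python
import numpy as np


def compute_accuracy(eval_pred):
    """Hypothetical stand-in for compute_metrics: accuracy from (logits, labels)."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)  # predicted class id per sample
    return {"accuracy": float((preds == labels).mean())}


# Made-up logits for four clips and two classes.
logits = np.array([[2.0, 0.1], [0.3, 1.2], [1.5, 0.2], [0.1, 0.9]])
labels = np.array([0, 1, 1, 1])
print(compute_accuracy((logits, labels)))  # {'accuracy': 0.75}
```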

In [91]:
with wandb.init(
    project=PROJECT, 
    job_type="train",
    name=run_name,  # Use same run name for consistency
    tags=[MODEL_NAME, DATASET],
    notes=f"Fine tuning {MODEL_NAME} with {DATASET}.",
    config={
        "model_name": MODEL_NAME,
        "dataset": DATASET,
        "batch_size": training_args.per_device_train_batch_size,
        "learning_rate": training_args.learning_rate,
        "num_epochs": training_args.num_train_epochs,
        "weight_decay": training_args.weight_decay
    }
) as run:
    train_results = trainer.train()



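| Step | Training Loss | Validation Loss | Accuracy |
|------|---------------|-----------------|----------|
| 10   | 0.6398        | 1.313759        | 0.3      |
| 20   | 0.4303        | 1.54323         | 0.3      |
| 30   | 0.2084        | 2.42444         | 0.3      |
| 40   | 0.1023        | 2.596303        | 0.4      |
| 50   | 0.0244        | 4.233907        | 0.3      |
| 60   | 0.0025        | 4.721967        | 0.4      |
| 70   | 0.0001        | 5.144505        | 0.4      |
| 80   | 0.0           | 6.597664        | 0.5      |
| 90   | 0.0           | 10.422966       | 0.3      |
| 100  | 0.0           | 10.393098       | 0.3      |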




Run history:
eval/accuracy,▁▁▁▅▁▅▅█▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅
eval/loss,▁▁▂▂▃▄▄▅████████████████
eval/runtime,▃▄▃▂▁▂▁▁▄█▃▁▁▁▁▁▂▁▃▄▁▁▃▃
eval/samples_per_second,▅▅▅▇█▇██▅▁▆█████▇█▆▅▇█▆▆
eval/steps_per_second,▅▅▅▇█▇██▅▁▆█████▇█▆▅██▆▆
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/grad_norm,▄▆█▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,██▇▇▇▆▆▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▁▁
train/loss,█▆▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
eval/accuracy,0.4
eval/loss,9.81847
eval/runtime,7.4439
eval/samples_per_second,1.343
eval/steps_per_second,0.134
total_flos,2.67132927430656e+18
train/epoch,40.0
train/global_step,240.0
train/grad_norm,0.03036
train/learning_rate,0.0


In [92]:
trainer.save_model("model")
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

***** train metrics *****
  epoch                    =         40.0
  total_flos               = 2487869257GF
  train_loss               =       0.0587
  train_runtime            =   0:46:14.45
  train_samples_per_second =        1.225
  train_steps_per_second   =        0.087


In [93]:
custom_path = "./model"

In [94]:
TRAINED_MODEL_NAME = "Duke-ViViT-Fine-tuned-iqm-filtered-1"
with wandb.init(project=PROJECT, name="upload-duke-vivit-iqm-filtered-model-1", job_type="models"):
    # Package the saved model directory as a versioned W&B artifact
    artifact = wandb.Artifact(TRAINED_MODEL_NAME, type="model")
    artifact.add_dir(custom_path)
    wandb.save(custom_path)
    wandb.log_artifact(artifact)


[34m[1mwandb[0m: Adding directory to artifact (./model)... Done. 1.3s


# Inference

In [96]:
path_files_val = "/home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_iqm_filtered:v0/"
video_dict_val, class_labels_val = frames_convert_and_create_dataset_dictionary(path_files_val)

Processing file /home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/train/referral/0165_processed.mp4 number of Frames: 509
Reading frames from 102 to 111
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
Processing file /home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/train/referral/0030_processed.mp4 number of Frames: 40
Reading frames from 19 to 28
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
Processing file /home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/train/referral/0197_processed.mp4 number of Frames: 32
Reading frames from 14 to 23
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
Processing file /home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_iqm_filtered:v0/dataset/train/referral/0090_processed.mp4 number of

In [97]:
val_dataset = create_dataset(video_dict_val)

Casting to class labels: 100%|██████████| 95/95 [00:00<00:00, 351.91 examples/s]
Processing videos: 100%|██████████| 95/95 [02:51<00:00,  1.80s/ examples]


First sample structure:
Keys: dict_keys(['labels', 'split', 'path', 'pixel_values'])
Pixel values type: <class 'list'>


In [98]:
import wandb
MODEL_ARTIFACT_NAME = "laryngeal_cancer_video_classification/Duke-ViViT-Fine-tuned-iqm-filtered-1:v0"
run = wandb.init(
    project=PROJECT,
    name="duke-vivit-fine-tuned-iqm-filtered-inference-1",  # run name
    tags=[MODEL_NAME, DATASET],
    entity="shaunliewsmu-singapore-management-university"
)
artifact = run.use_artifact(MODEL_ARTIFACT_NAME, type='model')
artifact_dir = artifact.download()

[34m[1mwandb[0m: Downloading large artifact Duke-ViViT-Fine-tuned-iqm-filtered-1:v0, 345.65MB. 4 files... 
[34m[1mwandb[0m:   4 of 4 files downloaded.  
Done. 0:0:0.7


In [99]:
artifact_dir

'/home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/Duke-ViViT-Fine-tuned-iqm-filtered-1:v0'

In [100]:
val_dataset

DatasetDict({
    train: Dataset({
        features: ['labels', 'split', 'path', 'pixel_values'],
        num_rows: 85
    })
    test: Dataset({
        features: ['labels', 'split', 'path', 'pixel_values'],
        num_rows: 10
    })
})

In [102]:
from data_handling import generate_all_files
import os
import numpy as np
import av
from pathlib import Path
def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    '''
    Sample a given number of frame indices from the video.
    Args:
        clip_len (`int`): Total number of frames to sample.
        frame_sample_rate (`int`): Sample every n-th frame.
        seg_len (`int`): Maximum allowed index of sample's last frame.
    Returns:
        indices (`List[int]`): List of sampled frame indices
    '''
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices
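
One thing to be aware of in `sample_frame_indices` above: `np.random.randint(converted_len, seg_len)` raises a `ValueError` when the video has no more frames than `clip_len * frame_sample_rate`. The sketch below is one possible guarded variant, not the version shipped in `data_handling`. The name `sample_frame_indices_safe` and the evenly spaced fallback are assumptions made for illustration.

```python
import numpy as np


def sample_frame_indices_safe(clip_len, frame_sample_rate, seg_len, rng=None):
    """Like sample_frame_indices, but falls back to evenly spaced frames
    when the sampling window does not fit into the video."""
    rng = rng or np.random.default_rng()
    converted_len = int(clip_len * frame_sample_rate)
    if seg_len <= converted_len:
        # Window does not fit: spread clip_len indices over the whole video.
        # Indices repeat when seg_len < clip_len.
        return np.linspace(0, seg_len - 1, num=clip_len).astype(np.int64)
    end_idx = int(rng.integers(converted_len, seg_len))  # upper bound is exclusive
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    return np.clip(indices, start_idx, end_idx - 1).astype(np.int64)


short = sample_frame_indices_safe(clip_len=10, frame_sample_rate=2, seg_len=15)
print(short)  # 10 indices spread between 0 and 14
```

If indices do repeat, `read_video_pyav` as written above appends each decoded frame only once, so it would return fewer than `clip_len` frames. Very short videos therefore still need separate handling, such as padding.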

In [103]:
from transformers import VivitConfig
labels = val_dataset['train'].features['labels'].names
config = VivitConfig.from_pretrained(artifact_dir)
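config.num_labels = len(labels)  # the Transformers config field is `num_labels`, not `num_classes`
# String keys are kept because the inference cells look labels up with str(predicted_label)
config.id2label = {str(i): c for i, c in enumerate(labels)}
config.label2id = {c: str(i) for i, c in enumerate(labels)}
config.num_frames = 10
config.video_size = [10, 224, 224]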

In [104]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [105]:
from transformers import VivitImageProcessor, VivitForVideoClassification

In [106]:
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
fine_tune_model = VivitForVideoClassification.from_pretrained(artifact_dir,config=config)

In [107]:
from pathlib import Path
import os
import av
import torch
from data_handling import sample_frame_indices, read_video_pyav, generate_all_files


# Run inference with the fine-tuned model on the validation set

In [108]:
base_dir = Path("/home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_balanced:v0/dataset/val")


In [109]:
# Initialize empty lists
class_labels = []
true_labels = []
predictions = []
predictions_labels = []
all_videos = []
video_files = []
sizes = []
# Iterate through files
for file_path in generate_all_files(base_dir, only_files=True):
    # Only process video files
    if not str(file_path).endswith('.mp4'):
        continue
        
    try:
        # Get dataset split and class from path
        split = file_path.parent.parent.name  # dataset split (here: val)
        cls = file_path.parent.name  # class name
        
        # Process class labels
        if cls not in class_labels:
            class_labels.append(cls)
        
        true_labels.append(cls)
        
        print(f"Processing file: {file_path}")
        
        # Open and process video
        container = av.open(str(file_path))
        num_frames = container.streams.video[0].frames
        
        indices = sample_frame_indices(
            clip_len=10, 
            frame_sample_rate=1,
            seg_len=num_frames
        )
        
        video = read_video_pyav(container=container, indices=indices)
        inputs = image_processor(list(video), return_tensors="pt")
        
        # Get model predictions
        with torch.no_grad():
            outputs = fine_tune_model(**inputs)
            logits = outputs.logits
            
        predicted_label = logits.argmax(-1).item()
        prediction = fine_tune_model.config.id2label[str(predicted_label)]
        
        predictions.append(prediction)
        predictions_labels.append(predicted_label)
        
        print(f"File: {file_path.name}")
        print(f"True Label: {cls}")
        print(f"Predicted Label: {prediction}")
        print("-" * 50)
        
    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        continue
    finally:
        if 'container' in locals():
            container.close()

print("\nProcessing complete!")
print(f"Total videos processed: {len(predictions)}")
print(f"Unique classes found: {len(class_labels)}")

Processing file: /home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_balanced:v0/dataset/val/referral/0224.mp4
Reading frames from 763 to 772
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
File: 0224.mp4
True Label: referral
Predicted Label: referral
--------------------------------------------------
Processing file: /home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_balanced:v0/dataset/val/referral/0220.mp4
Reading frames from 2235 to 2244
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
File: 0220.mp4
True Label: referral
Predicted Label: referral
--------------------------------------------------
Processing file: /home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_balanced:v0/dataset/val/referral/0124.mp4
Reading frames from 379 to 388
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)
File: 0124.mp4
True Label: referral
Pred

In [110]:
from sklearn.metrics import classification_report

In [111]:
report = classification_report(true_labels, predictions)
print(report)

              precision    recall  f1-score   support

non_referral       0.67      0.67      0.67         6
    referral       0.86      0.86      0.86        14

    accuracy                           0.80        20
   macro avg       0.76      0.76      0.76        20
weighted avg       0.80      0.80      0.80        20
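
The report does not print the confusion matrix, but the counts can be inferred from the recall and support columns: 4 of 6 `non_referral` and 12 of 14 `referral` videos classified correctly. The sketch below uses these inferred counts (an assumption derived from the report, not a printed result) and recomputes the reported figures in plain Python.

```python
# Counts inferred from the report: outer key = true class, inner key = predicted class.
confusion = {
    "non_referral": {"non_referral": 4, "referral": 2},
    "referral": {"non_referral": 2, "referral": 12},
}


def precision(cls):
    """Correct predictions of cls divided by all predictions of cls."""
    predicted = sum(row[cls] for row in confusion.values())
    return confusion[cls][cls] / predicted


def recall(cls):
    """Correct predictions of cls divided by all true members of cls."""
    return confusion[cls][cls] / sum(confusion[cls].values())


for cls in confusion:
    print(cls, round(precision(cls), 2), round(recall(cls), 2))
# non_referral 0.67 0.67
# referral 0.86 0.86

total = sum(sum(row.values()) for row in confusion.values())
accuracy = sum(confusion[c][c] for c in confusion) / total
print(accuracy)  # 0.8
```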



# Inference using one video sample

In [112]:
file_name = "/home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_balanced:v0/dataset/val/referral/0039.mp4"
container = av.open(file_name)

In [113]:
indices = sample_frame_indices(clip_len=10, frame_sample_rate=3,seg_len=container.streams.video[0].frames)
print(f"Processing file {file_name} number of Frames: {container.streams.video[0].frames}")  
video = read_video_pyav(container=container, indices=indices)
inputs = image_processor(list(video), return_tensors="pt")

Processing file /home/shaunliew/ai-laryngeal-video-based-classifier/artifacts/laryngeal_dataset_balanced:v0/dataset/val/referral/0039.mp4 number of Frames: 600
Reading frames from 489 to 518
Number of frames extracted: 10
Final video array shape: (10, 224, 224, 3)


In [114]:

with torch.no_grad():
    outputs = fine_tune_model(**inputs)
    logits = outputs.logits

In [115]:
predicted_label = logits.argmax(-1).item()
prediction = fine_tune_model.config.id2label[str(predicted_label)]
prediction

'referral'
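
`argmax` only returns the winning class. If a confidence score per class is also useful, a softmax over the logits gives one. The sketch below uses NumPy with made-up logits, since the actual logit values are not printed in this notebook. The `id2label` mapping mirrors the one set on the config earlier.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


id2label = {"0": "non_referral", "1": "referral"}  # same mapping as the config
example_logits = np.array([[-0.4, 1.3]])           # hypothetical values
probs = softmax(example_logits)[0]
for i, p in enumerate(probs):
    print(f"{id2label[str(i)]}: {p:.3f}")
```

With the real model output, `outputs.logits.softmax(-1)` in PyTorch should give the same quantity directly.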

In [116]:
# close the wandb run
run.finish()