In [None]:
import os
import subprocess
def check_mig_with_smi() -> bool:
    """Report whether ``nvidia-smi -L`` lists a MIG device.

    Runs ``nvidia-smi -L`` and looks for the substring ``"MIG"`` in its
    standard output.

    Returns:
        True if the output contains ``"MIG"``. False if it does not, or if
        ``nvidia-smi`` could not be run (missing, not executable, or exited
        with a non-zero status).
    """
    try:
        output = subprocess.check_output(["nvidia-smi", "-L"], text=True)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (binary absent) as well as
        # PermissionError and other launch failures. All of them mean
        # "cannot detect MIG", which this best-effort probe reports as False.
        return False
    return "MIG" in output

# Probe once, then pin the process to device index 0 when a MIG device is listed.
# The environment variable is written before torch is imported in a later cell.
running_on_mig = check_mig_with_smi()
if not running_on_mig:
    print("Not running on a GPU MIG instance")
else:
    print("Running on a GPU MIG instance")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [24]:
import torch

from modelComp.FluidGPT_B import FluidGPT_B

# Path to the model-config YAML (relative to the notebook's working directory);
# it is parsed further down by load_yaml_as_dotdict.
yaml_path = "../conf/model/std-5.yaml"
import yaml


class DotDict(dict):
    """Dictionary whose keys can also be read, written and deleted as attributes.

    Plain ``dict`` values are wrapped in ``DotDict`` both at construction time
    and when stored later through ``d[key] = ...`` or ``d.key = ...``, so
    chained access such as ``cfg.a.b`` keeps working after assignment.

    Notes:
        * Attribute access falls back to key lookup only when normal attribute
          lookup fails, so keys that collide with ``dict`` method names
          (``items``, ``keys``, ...) must be read with ``cfg["items"]``.
        * Values inserted through ``update()`` / ``setdefault()`` bypass
          ``__setitem__`` and are therefore not wrapped.
    """

    def __init__(self, mapping=None):
        """Populate from ``mapping`` (``None`` or empty gives an empty DotDict)."""
        super().__init__()
        mapping = mapping or {}
        for key, value in mapping.items():
            self[key] = DotDict(value) if isinstance(value, dict) else value

    def __setitem__(self, key, value):
        """Store ``value``, wrapping a plain dict so it supports dot access too."""
        if isinstance(value, dict) and not isinstance(value, DotDict):
            value = DotDict(value)
        super().__setitem__(key, value)

    def __getattr__(self, key):
        """Resolve a missing attribute as a key; raise AttributeError if absent."""
        try:
            return self[key]
        except KeyError:
            # `from None` hides the internal KeyError from the traceback.
            raise AttributeError(f"Key '{key}' not in config") from None

    def __setattr__(self, key, value):
        """Attribute assignment writes the corresponding key."""
        self[key] = value

    def __delattr__(self, key):
        """Attribute deletion removes the corresponding key."""
        try:
            del self[key]
        except KeyError:
            raise AttributeError(f"Key '{key}' not in config") from None

def load_yaml_as_dotdict(filepath):
    """Parse a YAML file and return its contents as a ``DotDict``.

    Args:
        filepath: Path of the YAML file to read.

    Returns:
        A ``DotDict`` built from the parsed document. A document that parses
        to a falsy value (for example an empty file) yields an empty
        ``DotDict``.
    """
    # Decode as UTF-8 explicitly: the platform default (e.g. cp1252 on Windows)
    # would otherwise mis-decode non-ASCII characters in the config.
    with open(filepath, "r", encoding="utf-8") as file:
        data = yaml.safe_load(file) or {}  # safe_load gives None for an empty document
    return DotDict(data)

# Parsed config; its keys are read as attributes when the model is built
# (e.g. cm.depth, cm.num_heads).
cm = load_yaml_as_dotdict(yaml_path)

In [25]:
# Instantiate FluidGPT_B from the YAML hyper-parameters held in `cm`.
# emb_dim and the fixed entries of data_dim (64, 128, 128) are hard-coded here.
model = FluidGPT_B(
    emb_dim=96,
    data_dim=[64, cm.temporal_bundling, cm.in_channels, 128, 128],
    patch_size=(cm.patch_size, cm.patch_size),  # same value passed for both entries
    hiddenout_dim=cm.hiddenout_dim,
    depth=cm.depth,
    stage_depths=cm.stage_depths,
    num_heads=cm.num_heads,
    window_size=cm.window_size,
    use_flex_attn=cm.use_flex_attn,
).cuda()  # move the freshly built model to the GPU

# Switch to evaluation mode before loading weights and running inference.
model.eval()
# Load the model weights.
# NOTE(review): this is a machine-specific absolute path; point it at your own
# copy of the checkpoint before running on another machine.
checkpoint_path = "C:/Users/20183172/Documents/2024-II/prep/epoch=28-step=563673.ckpt"
# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cuda")
state_dict = checkpoint["state_dict"]


def _strip_leading_prefix(key, prefix="model."):
    """Remove `prefix` from the start of `key` (repeatedly), leaving the rest intact.

    `str.replace` would also delete inner occurrences and corrupt names such
    as "submodel.weight"; only the leading wrapper prefix should be dropped.
    """
    while key.startswith(prefix):
        key = key[len(prefix):]
    return key


# Checkpoint keys carry a "model." wrapper prefix (presumably from the training
# wrapper -- confirm); drop it so the keys line up with the bare model.
new_state_dict = {_strip_leading_prefix(k): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [4]:
# Run one forward pass on random input with gradient tracking disabled.
# NOTE(review): the shape presumably follows the layout of data_dim passed to
# the model (batch first) -- confirm against FluidGPT_B.
with torch.no_grad():
    x = torch.randn((16, 3, 2, 128, 128)).cuda()
    y = model(x)