Skip to content

Commit

Permalink
test: added unit test for getGPUSKU (#525)
Browse files Browse the repository at this point in the history
  • Loading branch information
smritidahal653 committed Apr 10, 2023
1 parent 582eac5 commit c88ad2e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pkg/provider/aci.go
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,7 @@ func (p *ACIProvider) getGPUSKU(pod *v1.Pod) (azaciv2.GpuSKU, error) {
}
}

return "", fmt.Errorf("the pod requires GPU SKU %s, but ACI only supports SKUs %v in region %s", desiredSKU, p.region, p.gpuSKUs)
return "", fmt.Errorf("the pod requires GPU SKU %s, but ACI only supports SKUs %v in region %s", desiredSKU, p.gpuSKUs, p.region)
}

return p.gpuSKUs[0], nil
Expand Down
73 changes: 73 additions & 0 deletions pkg/provider/aci_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1785,5 +1785,78 @@ func TestGetContainerLogs(t *testing.T) {

})
}
}

func TestGetGPUSKU(t *testing.T) {
podName := "pod-" + uuid.New().String()
podNamespace := "ns-" + uuid.New().String()

mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

aciMocks := createNewACIMock()

provider, err := createTestProvider(aciMocks, NewMockConfigMapLister(mockCtrl),
NewMockSecretLister(mockCtrl), NewMockPodLister(mockCtrl))
if err != nil {
t.Fatal("failed to create the test provider", err)
}

cases := []struct {
description string
gpuSkus []azaciv2.GpuSKU
desiredSku string
expectedError error
}{
{
description: "gpuTypeAnnotation is not set but ACI provides gpusku",
gpuSkus: []azaciv2.GpuSKU{azaciv2.GpuSKUK80, azaciv2.GpuSKUP100},
desiredSku: "",
expectedError: nil,
},
{
description: "gpuTypeAnnotation is set and the desired sku is supported by ACI",
gpuSkus: []azaciv2.GpuSKU{azaciv2.GpuSKUK80, azaciv2.GpuSKUP100},
desiredSku: "P100",
expectedError: nil,
},
{
description: "gpuTypeAnnotation is set but the desired sku is not supported by ACI",
gpuSkus: []azaciv2.GpuSKU{azaciv2.GpuSKUK80, azaciv2.GpuSKUP100},
desiredSku: "P120",
expectedError: fmt.Errorf("the pod requires GPU SKU P120, but ACI only supports SKUs %v in region %s", []azaciv2.GpuSKU{azaciv2.GpuSKUK80, azaciv2.GpuSKUP100}, provider.region),
},
{
description: "ACI doesn't provide any gpusku",
gpuSkus: []azaciv2.GpuSKU{},
desiredSku: "",
expectedError: fmt.Errorf("the pod requires GPU resource, but ACI doesn't provide GPU enabled container group in region %s", provider.region),
},
}

for _, tc := range cases {
t.Run(tc.description, func(t *testing.T) {
provider.gpuSKUs = tc.gpuSkus

pod := testsutil.CreatePodObj(podName, podNamespace)
if len(tc.desiredSku) > 0 {
pod.Annotations = map[string]string{}
pod.Annotations[gpuTypeAnnotation] = tc.desiredSku
}

gpuSKU, err := provider.getGPUSKU(pod)

if tc.expectedError != nil {
assert.Equal(t, err.Error(), tc.expectedError.Error(), "Error messages should match")
assert.Equal(t, string(gpuSKU), "", "No GPU SKU should be returned")
} else {
assert.NilError(t, err, "no error should be returned")
if len(tc.desiredSku) == 0 {
assert.Equal(t, gpuSKU, tc.gpuSkus[0], "Since no desired SKU was set, the first gpuSKU in the list should be returned")
} else {
assert.Equal(t, string(gpuSKU), tc.desiredSku, "Desired SKU should be returned")
}
}
})
}
}

0 comments on commit c88ad2e

Please sign in to comment.